# rpn.py (15 KB)
  1. from typing import Optional
  2. import torch
  3. from torch import nn, Tensor
  4. from torch.nn import functional as F
  5. from torchvision.ops import boxes as box_ops, Conv2dNormActivation
  6. from . import _utils as det_utils
  7. # Import AnchorGenerator to keep compatibility.
  8. from .anchor_utils import AnchorGenerator # noqa: 401
  9. from .image_list import ImageList
  class RPNHead(nn.Module):
      """
      Adds a simple RPN Head with classification and regression heads.

      A stack of ``conv_depth`` 3x3 ``Conv2dNormActivation`` blocks (built with
      ``norm_layer=None``) is applied to every feature map; two 1x1 convolutions
      then produce ``num_anchors`` objectness channels and ``num_anchors * 4``
      box-regression channels.

      Args:
          in_channels (int): number of channels of the input feature
          num_anchors (int): number of anchors to be predicted
          conv_depth (int, optional): number of convolutions
      """

      # Bumped to 2 when ``conv`` became a Sequential; see _load_from_state_dict.
      _version = 2

      def __init__(self, in_channels: int, num_anchors: int, conv_depth: int = 1) -> None:
          super().__init__()
          self.conv = nn.Sequential(
              *[
                  Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None)
                  for _ in range(conv_depth)
              ]
          )
          self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
          self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

          # Re-initialise every Conv2d in the head (including those nested inside
          # ``self.conv``) with N(0, 0.01) weights and zero bias.
          for layer in self.modules():
              if isinstance(layer, nn.Conv2d):
                  torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                  if layer.bias is not None:
                      torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

      def _load_from_state_dict(
          self,
          state_dict,
          prefix,
          local_metadata,
          strict,
          missing_keys,
          unexpected_keys,
          error_msgs,
      ):
          """
          Load this module's parameters, migrating pre-version-2 checkpoints.

          When ``local_metadata`` carries no version or a version below 2, the keys
          ``{prefix}conv.weight`` / ``{prefix}conv.bias`` are renamed in place in
          ``state_dict`` to ``{prefix}conv.0.0.weight`` / ``{prefix}conv.0.0.bias``
          before delegating to the parent implementation.
          """
          version = local_metadata.get("version", None)

          if version is None or version < 2:
              for param_name in ("weight", "bias"):
                  old_key = f"{prefix}conv.{param_name}"
                  new_key = f"{prefix}conv.0.0.{param_name}"
                  if old_key in state_dict:
                      state_dict[new_key] = state_dict.pop(old_key)

          super()._load_from_state_dict(
              state_dict,
              prefix,
              local_metadata,
              strict,
              missing_keys,
              unexpected_keys,
              error_msgs,
          )

      def forward(self, x: list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
          """
          Run the head on every feature map.

          Args:
              x (List[Tensor]): feature maps, one per level

          Returns:
              logits (List[Tensor]): output of ``cls_logits`` for each feature map
              bbox_reg (List[Tensor]): output of ``bbox_pred`` for each feature map
          """
          logits = []
          bbox_reg = []
          for feature in x:
              t = self.conv(feature)
              logits.append(self.cls_logits(t))
              bbox_reg.append(self.bbox_pred(t))
          return logits, bbox_reg
  def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
      """
      Rearrange a per-level prediction map into shape ``(N, -1, C)``.

      The input is viewed as ``(N, -1, C, H, W)``, its axes are reordered to
      ``(N, H, W, -1, C)``, and the three middle axes are merged into one.

      Args:
          layer (Tensor): prediction map to rearrange
          N (int): size of the leading dimension
          A (int): kept for interface compatibility; not read by this function
          C (int): size of the trailing dimension of the result
          H (int): height of the map
          W (int): width of the map

      Returns:
          Tensor: the rearranged tensor of shape ``(N, -1, C)``
      """
      return layer.view(N, -1, C, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, C)
  def concat_box_prediction_layers(box_cls: list[Tensor], box_regression: list[Tensor]) -> tuple[Tensor, Tensor]:
      """
      Flatten per-level predictions and join all levels into two 2-D tensors.

      Args:
          box_cls (List[Tensor]): per-level classification maps, each unpacked
              as ``(N, AxC, H, W)``
          box_regression (List[Tensor]): per-level regression maps whose second
              dimension is ``A * 4``

      Returns:
          box_cls (Tensor): all levels concatenated, with every dimension except
              the last merged
          box_regression (Tensor): all levels concatenated, reshaped to ``(-1, 4)``
      """
      cls_per_level = []
      reg_per_level = []

      # Each level is permuted into the same layout as the labels. The labels are
      # computed over all feature levels concatenated, so objectness and box
      # regression keep that same representation.
      for level_cls, level_reg in zip(box_cls, box_regression):
          batch, anchors_x_classes, height, width = level_cls.shape
          num_anchors = level_reg.shape[1] // 4
          num_classes = anchors_x_classes // num_anchors

          cls_per_level.append(
              permute_and_flatten(level_cls, batch, num_anchors, num_classes, height, width)
          )
          reg_per_level.append(
              permute_and_flatten(level_reg, batch, num_anchors, 4, height, width)
          )

      # Concatenate along dim 1 (the feature levels) to mirror how the labels
      # were generated, i.e. with all feature maps concatenated as well.
      flat_cls = torch.cat(cls_per_level, dim=1).flatten(0, -2)
      flat_reg = torch.cat(reg_per_level, dim=1).reshape(-1, 4)
      return flat_cls, flat_reg
  class RegionProposalNetwork(torch.nn.Module):
      """
      Implements Region Proposal Network (RPN).

      Args:
          anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
              maps.
          head (nn.Module): module that computes the objectness and regression deltas
          fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
              considered as positive during training of the RPN.
          bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
              considered as negative during training of the RPN.
          batch_size_per_image (int): number of anchors that are sampled during training of the RPN
              for computing the loss
          positive_fraction (float): proportion of positive anchors in a mini-batch during training
              of the RPN
          pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
              contain two fields: training and testing, to allow for different values depending
              on training or evaluation
          post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
              contain two fields: training and testing, to allow for different values depending
              on training or evaluation
          nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
          score_thresh (float): only return proposals with an objectness score greater than score_thresh
      """

      __annotations__ = {
          "box_coder": det_utils.BoxCoder,
          "proposal_matcher": det_utils.Matcher,
          "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
      }

      def __init__(
          self,
          anchor_generator: AnchorGenerator,
          head: nn.Module,
          # Faster-RCNN Training
          fg_iou_thresh: float,
          bg_iou_thresh: float,
          batch_size_per_image: int,
          positive_fraction: float,
          # Faster-RCNN Inference
          pre_nms_top_n: dict[str, int],
          post_nms_top_n: dict[str, int],
          nms_thresh: float,
          score_thresh: float = 0.0,
      ) -> None:
          super().__init__()
          self.anchor_generator = anchor_generator
          self.head = head
          self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

          # Components consulted while training.
          self.box_similarity = box_ops.box_iou
          self.proposal_matcher = det_utils.Matcher(
              fg_iou_thresh,
              bg_iou_thresh,
              allow_low_quality_matches=True,
          )
          self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

          # Settings consulted while filtering proposals.
          self._pre_nms_top_n = pre_nms_top_n
          self._post_nms_top_n = post_nms_top_n
          self.nms_thresh = nms_thresh
          self.score_thresh = score_thresh
          self.min_size = 1e-3

      def pre_nms_top_n(self) -> int:
          """Return the pre-NMS proposal budget for the current mode (training or testing)."""
          mode = "training" if self.training else "testing"
          return self._pre_nms_top_n[mode]

      def post_nms_top_n(self) -> int:
          """Return the post-NMS proposal budget for the current mode (training or testing)."""
          mode = "training" if self.training else "testing"
          return self._post_nms_top_n[mode]

      def assign_targets_to_anchors(
          self, anchors: list[Tensor], targets: list[dict[str, Tensor]]
      ) -> tuple[list[Tensor], list[Tensor]]:
          """
          Label every anchor of every image and pick the GT box it is matched to.

          Labels are float32: 1.0 where the matched index is non-negative, 0.0 where
          the matcher reports ``BELOW_LOW_THRESHOLD`` and -1.0 where it reports
          ``BETWEEN_THRESHOLDS``. Images without GT boxes get all-zero labels and
          all-zero matched boxes.

          Args:
              anchors (List[Tensor]): anchors, one tensor per image
              targets (List[Dict[str, Tensor]]): per-image targets; only the
                  ``"boxes"`` entry is read

          Returns:
              labels (List[Tensor]): per-image anchor labels
              matched_gt_boxes (List[Tensor]): per-image GT box selected for each anchor
          """
          all_labels = []
          all_matched_boxes = []
          for image_anchors, image_targets in zip(anchors, targets):
              gt_boxes = image_targets["boxes"]

              if gt_boxes.numel() != 0:
                  iou_matrix = self.box_similarity(gt_boxes, image_anchors)
                  match_idx = self.proposal_matcher(iou_matrix)
                  # The indices are clamped before indexing: with a single GT in the
                  # image, match_idx can be -2, which would go out of bounds.
                  image_matched = gt_boxes[match_idx.clamp(min=0)]

                  image_labels = (match_idx >= 0).to(dtype=torch.float32)

                  # Background (negative examples)
                  below_low = match_idx == self.proposal_matcher.BELOW_LOW_THRESHOLD
                  image_labels[below_low] = 0.0

                  # Anchors that fall between the two thresholds are discarded
                  between = match_idx == self.proposal_matcher.BETWEEN_THRESHOLDS
                  image_labels[between] = -1.0
              else:
                  # Background image (negative example)
                  device = image_anchors.device
                  image_matched = torch.zeros(image_anchors.shape, dtype=torch.float32, device=device)
                  image_labels = torch.zeros((image_anchors.shape[0],), dtype=torch.float32, device=device)

              all_labels.append(image_labels)
              all_matched_boxes.append(image_matched)
          return all_labels, all_matched_boxes

      def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: list[int]) -> Tensor:
          """
          Pick the top-scoring anchor indices of each level along dim 1.

          ``objectness`` is split along dim 1 by ``num_anchors_per_level``. The indices
          found within a level are shifted by the number of anchors in all preceding
          levels, so they address the unsplit tensor.
          """
          picked = []
          start = 0
          for level_scores in objectness.split(num_anchors_per_level, 1):
              level_size = level_scores.shape[1]
              k = det_utils._topk_min(level_scores, self.pre_nms_top_n(), 1)
              top_idx = level_scores.topk(k, dim=1)[1]
              picked.append(top_idx + start)
              start += level_size
          return torch.cat(picked, dim=1)

      def filter_proposals(
          self,
          proposals: Tensor,
          objectness: Tensor,
          image_shapes: list[tuple[int, int]],
          num_anchors_per_level: list[int],
      ) -> tuple[list[Tensor], list[Tensor]]:
          """
          Reduce the raw proposals of each image to a final scored set.

          Per level, the top ``pre_nms_top_n()`` proposals are kept. Per image, boxes
          are then clipped to the image, small boxes and boxes scoring below
          ``score_thresh`` are removed, NMS is applied independently per level, and
          the first ``post_nms_top_n()`` survivors are returned.

          Args:
              proposals (Tensor): decoded proposals, indexed as ``[image, anchor]``
              objectness (Tensor): objectness logits; detached and reshaped to
                  ``(num_images, -1)`` here
              image_shapes (List[Tuple[int, int]]): per-image sizes used for clipping
              num_anchors_per_level (List[int]): number of anchors at each level

          Returns:
              final_boxes (List[Tensor]): kept boxes, one tensor per image
              final_scores (List[Tensor]): sigmoid objectness of the kept boxes
          """
          num_images = proposals.shape[0]
          device = proposals.device

          # do not backprop through objectness
          logits = objectness.detach().reshape(num_images, -1)

          # Tag every anchor with the index of the level it belongs to.
          level_tags = [
              torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
          ]
          level_ids = torch.cat(level_tags, 0).reshape(1, -1).expand_as(logits)

          # select top_n boxes independently per level before applying nms
          top_n_idx = self._get_top_n_idx(logits, num_anchors_per_level)

          batch_idx = torch.arange(num_images, device=device)[:, None]

          logits = logits[batch_idx, top_n_idx]
          level_ids = level_ids[batch_idx, top_n_idx]
          proposals = proposals[batch_idx, top_n_idx]

          probabilities = torch.sigmoid(logits)

          final_boxes = []
          final_scores = []
          for img_boxes, img_scores, img_levels, img_shape in zip(proposals, probabilities, level_ids, image_shapes):
              img_boxes = box_ops.clip_boxes_to_image(img_boxes, img_shape)

              # remove small boxes
              big_enough = box_ops.remove_small_boxes(img_boxes, self.min_size)
              img_boxes = img_boxes[big_enough]
              img_scores = img_scores[big_enough]
              img_levels = img_levels[big_enough]

              # remove low scoring boxes
              # use >= for Backwards compatibility
              confident = torch.where(img_scores >= self.score_thresh)[0]
              img_boxes = img_boxes[confident]
              img_scores = img_scores[confident]
              img_levels = img_levels[confident]

              # non-maximum suppression, independently done per level
              survivors = box_ops.batched_nms(img_boxes, img_scores, img_levels, self.nms_thresh)

              # keep only topk scoring predictions
              survivors = survivors[: self.post_nms_top_n()]

              final_boxes.append(img_boxes[survivors])
              final_scores.append(img_scores[survivors])
          return final_boxes, final_scores

      def compute_loss(
          self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: list[Tensor], regression_targets: list[Tensor]
      ) -> tuple[Tensor, Tensor]:
          """
          Compute the objectness and box-regression losses on sampled anchors.

          Args:
              objectness (Tensor)
              pred_bbox_deltas (Tensor)
              labels (List[Tensor])
              regression_targets (List[Tensor])

          Returns:
              objectness_loss (Tensor)
              box_loss (Tensor)
          """
          pos_masks, neg_masks = self.fg_bg_sampler(labels)
          pos_idx = torch.where(torch.cat(pos_masks, dim=0))[0]
          neg_idx = torch.where(torch.cat(neg_masks, dim=0))[0]
          sampled_idx = torch.cat([pos_idx, neg_idx], dim=0)

          flat_objectness = objectness.flatten()
          flat_labels = torch.cat(labels, dim=0)
          flat_targets = torch.cat(regression_targets, dim=0)

          # Summed smooth-L1 over the sampled positives, divided by the total
          # number of sampled anchors (positives and negatives).
          box_loss = F.smooth_l1_loss(
              pred_bbox_deltas[pos_idx],
              flat_targets[pos_idx],
              beta=1 / 9,
              reduction="sum",
          ) / (sampled_idx.numel())

          objectness_loss = F.binary_cross_entropy_with_logits(flat_objectness[sampled_idx], flat_labels[sampled_idx])

          return objectness_loss, box_loss

      def forward(
          self,
          images: ImageList,
          features: dict[str, Tensor],
          targets: Optional[list[dict[str, Tensor]]] = None,
      ) -> tuple[list[Tensor], dict[str, Tensor]]:
          """
          Args:
              images (ImageList): images for which we want to compute the predictions
              features (Dict[str, Tensor]): features computed from the images that are
                  used for computing the predictions. Each tensor in the dict
                  corresponds to a different feature level
              targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                  If provided, each element in the dict should contain a field `boxes`,
                  with the locations of the ground-truth boxes.

          Returns:
              boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                  image.
              losses (Dict[str, Tensor]): the losses for the model during training. During
                  testing, it is an empty dict.

          Raises:
              ValueError: if the module is in training mode and ``targets`` is None.
          """
          # RPN uses all feature maps that are available
          features = list(features.values())
          objectness, pred_bbox_deltas = self.head(features)
          anchors = self.anchor_generator(images, features)

          num_images = len(anchors)
          # Anchors per level = product of the three dims of one image's objectness map.
          per_level_shapes = [o[0].shape for o in objectness]
          num_anchors_per_level = [s[0] * s[1] * s[2] for s in per_level_shapes]
          objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)

          # apply pred_bbox_deltas to anchors to obtain the decoded proposals
          # note that we detach the deltas because Faster R-CNN does not backprop
          # through the proposals
          decoded = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
          proposals = decoded.view(num_images, -1, 4)
          boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

          losses = {}
          if self.training:
              if targets is None:
                  raise ValueError("targets should not be None")
              labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
              regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
              loss_objectness, loss_rpn_box_reg = self.compute_loss(
                  objectness, pred_bbox_deltas, labels, regression_targets
              )
              losses = {
                  "loss_objectness": loss_objectness,
                  "loss_rpn_box_reg": loss_rpn_box_reg,
              }
          return boxes, losses