generalized_rcnn.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. """
  2. Implements the Generalized R-CNN framework
  3. """
  4. import warnings
  5. from collections import OrderedDict
  6. from typing import Optional, Union
  7. import torch
  8. from torch import nn
  9. from ...utils import _log_api_usage_once
  10. class GeneralizedRCNN(nn.Module):
  11. """
  12. Main class for Generalized R-CNN.
  13. Args:
  14. backbone (nn.Module):
  15. rpn (nn.Module):
  16. roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
  17. detections / masks from it.
  18. transform (nn.Module): performs the data transformation from the inputs to feed into
  19. the model
  20. """
  21. def __init__(
  22. self,
  23. backbone: nn.Module,
  24. rpn: nn.Module,
  25. roi_heads: nn.Module,
  26. transform: nn.Module,
  27. ) -> None:
  28. super().__init__()
  29. _log_api_usage_once(self)
  30. self.transform = transform
  31. self.backbone = backbone
  32. self.rpn = rpn
  33. self.roi_heads = roi_heads
  34. # used only on torchscript mode
  35. self._has_warned = False
  36. @torch.jit.unused
  37. def eager_outputs(
  38. self, losses: dict[str, torch.Tensor], detections: list[dict[str, torch.Tensor]]
  39. ) -> Union[dict[str, torch.Tensor], list[dict[str, torch.Tensor]]]:
  40. if self.training:
  41. return losses
  42. return detections
  43. def forward(
  44. self,
  45. images: list[torch.Tensor],
  46. targets: Optional[list[dict[str, torch.Tensor]]] = None,
  47. ) -> tuple[dict[str, torch.Tensor], list[dict[str, torch.Tensor]]]:
  48. """
  49. Args:
  50. images (list[Tensor]): images to be processed
  51. targets (list[dict[str, tensor]]): ground-truth boxes present in the image (optional)
  52. Returns:
  53. result (list[BoxList] or dict[Tensor]): the output from the model.
  54. During training, it returns a dict[Tensor] which contains the losses.
  55. During testing, it returns list[BoxList] contains additional fields
  56. like `scores`, `labels` and `mask` (for Mask R-CNN models).
  57. """
  58. if self.training:
  59. if targets is None:
  60. torch._assert(False, "targets should not be none when in training mode")
  61. else:
  62. for target in targets:
  63. boxes = target["boxes"]
  64. if isinstance(boxes, torch.Tensor):
  65. torch._assert(
  66. len(boxes.shape) == 2 and boxes.shape[-1] == 4,
  67. f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
  68. )
  69. else:
  70. torch._assert(
  71. False,
  72. f"Expected target boxes to be of type Tensor, got {type(boxes)}.",
  73. )
  74. original_image_sizes: list[tuple[int, int]] = []
  75. for img in images:
  76. val = img.shape[-2:]
  77. torch._assert(
  78. len(val) == 2,
  79. f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
  80. )
  81. original_image_sizes.append((val[0], val[1]))
  82. images, targets = self.transform(images, targets)
  83. # Check for degenerate boxes
  84. # TODO: Move this to a function
  85. if targets is not None:
  86. for target_idx, target in enumerate(targets):
  87. boxes = target["boxes"]
  88. degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
  89. if degenerate_boxes.any():
  90. # print the first degenerate box
  91. bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
  92. degen_bb: list[float] = boxes[bb_idx].tolist()
  93. torch._assert(
  94. False,
  95. "All bounding boxes should have positive height and width."
  96. f" Found invalid box {degen_bb} for target at index {target_idx}.",
  97. )
  98. features = self.backbone(images.tensors)
  99. if isinstance(features, torch.Tensor):
  100. features = OrderedDict([("0", features)])
  101. proposals, proposal_losses = self.rpn(images, features, targets)
  102. detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
  103. detections = self.transform.postprocess(
  104. detections, images.image_sizes, original_image_sizes
  105. ) # type: ignore[operator]
  106. losses = {}
  107. losses.update(detector_losses)
  108. losses.update(proposal_losses)
  109. if torch.jit.is_scripting():
  110. if not self._has_warned:
  111. warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
  112. self._has_warned = True
  113. return losses, detections
  114. else:
  115. return self.eager_outputs(losses, detections)