roi_heads.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887
  1. from typing import Optional
  2. import torch
  3. import torch.nn.functional as F
  4. import torchvision
  5. from torch import nn
  6. from torchvision.ops import boxes as box_ops, roi_align
  7. from . import _utils as det_utils
  8. def fastrcnn_loss(
  9. class_logits: torch.Tensor,
  10. box_regression: torch.Tensor,
  11. labels: list[torch.Tensor],
  12. regression_targets: list[torch.Tensor],
  13. ) -> tuple[torch.Tensor, torch.Tensor]:
  14. """
  15. Computes the loss for Faster R-CNN.
  16. Args:
  17. class_logits (Tensor)
  18. box_regression (Tensor)
  19. labels (list[BoxList])
  20. regression_targets (Tensor)
  21. Returns:
  22. classification_loss (Tensor)
  23. box_loss (Tensor)
  24. """
  25. labels = torch.cat(labels, dim=0)
  26. regression_targets = torch.cat(regression_targets, dim=0)
  27. classification_loss = F.cross_entropy(class_logits, labels)
  28. # get indices that correspond to the regression targets for
  29. # the corresponding ground truth labels, to be used with
  30. # advanced indexing
  31. sampled_pos_inds_subset = torch.where(labels > 0)[0]
  32. labels_pos = labels[sampled_pos_inds_subset]
  33. N, num_classes = class_logits.shape
  34. box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)
  35. box_loss = F.smooth_l1_loss(
  36. box_regression[sampled_pos_inds_subset, labels_pos],
  37. regression_targets[sampled_pos_inds_subset],
  38. beta=1 / 9,
  39. reduction="sum",
  40. )
  41. box_loss = box_loss / labels.numel()
  42. return classification_loss, box_loss
  43. def maskrcnn_inference(x: torch.Tensor, labels: list[torch.Tensor]) -> list[torch.Tensor]:
  44. """
  45. From the results of the CNN, post process the masks
  46. by taking the mask corresponding to the class with max
  47. probability (which are of fixed size and directly output
  48. by the CNN) and return the masks in the mask field of the BoxList.
  49. Args:
  50. x (Tensor): the mask logits
  51. labels (list[BoxList]): bounding boxes that are used as
  52. reference, one for each image
  53. Returns:
  54. results (list[BoxList]): one BoxList for each image, containing
  55. the extra field mask
  56. """
  57. mask_prob = x.sigmoid()
  58. # select masks corresponding to the predicted classes
  59. num_masks = x.shape[0]
  60. boxes_per_image = [label.shape[0] for label in labels]
  61. labels = torch.cat(labels)
  62. index = torch.arange(num_masks, device=labels.device)
  63. mask_prob = mask_prob[index, labels][:, None]
  64. mask_prob = mask_prob.split(boxes_per_image, dim=0)
  65. return mask_prob
  66. def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
  67. # type: (Tensor, Tensor, Tensor, int) -> Tensor
  68. """
  69. Given segmentation masks and the bounding boxes corresponding
  70. to the location of the masks in the image, this function
  71. crops and resizes the masks in the position defined by the
  72. boxes. This prepares the masks for them to be fed to the
  73. loss computation as the targets.
  74. """
  75. matched_idxs = matched_idxs.to(boxes)
  76. rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
  77. gt_masks = gt_masks[:, None].to(rois)
  78. return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
  79. def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
  80. # type: (Tensor, list[Tensor], list[Tensor], list[Tensor], list[Tensor]) -> Tensor
  81. """
  82. Args:
  83. proposals (list[BoxList])
  84. mask_logits (Tensor)
  85. targets (list[BoxList])
  86. Return:
  87. mask_loss (Tensor): scalar tensor containing the loss
  88. """
  89. discretization_size = mask_logits.shape[-1]
  90. labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
  91. mask_targets = [
  92. project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
  93. ]
  94. labels = torch.cat(labels, dim=0)
  95. mask_targets = torch.cat(mask_targets, dim=0)
  96. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  97. # accept empty tensors, so handle it separately
  98. if mask_targets.numel() == 0:
  99. return mask_logits.sum() * 0
  100. mask_loss = F.binary_cross_entropy_with_logits(
  101. mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
  102. )
  103. return mask_loss
  104. def keypoints_to_heatmap(keypoints, rois, heatmap_size):
  105. # type: (Tensor, Tensor, int) -> tuple[Tensor, Tensor]
  106. offset_x = rois[:, 0]
  107. offset_y = rois[:, 1]
  108. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  109. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  110. offset_x = offset_x[:, None]
  111. offset_y = offset_y[:, None]
  112. scale_x = scale_x[:, None]
  113. scale_y = scale_y[:, None]
  114. x = keypoints[..., 0]
  115. y = keypoints[..., 1]
  116. x_boundary_inds = x == rois[:, 2][:, None]
  117. y_boundary_inds = y == rois[:, 3][:, None]
  118. x = (x - offset_x) * scale_x
  119. x = x.floor().long()
  120. y = (y - offset_y) * scale_y
  121. y = y.floor().long()
  122. x[x_boundary_inds] = heatmap_size - 1
  123. y[y_boundary_inds] = heatmap_size - 1
  124. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  125. vis = keypoints[..., 2] > 0
  126. valid = (valid_loc & vis).long()
  127. lin_ind = y * heatmap_size + x
  128. heatmaps = lin_ind * valid
  129. return heatmaps, valid
def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    """Trace/ONNX-oriented decoding of the keypoint heatmaps of a single RoI.

    Performs the same computation as one iteration of the eager loop in
    ``heatmaps_to_keypoints``, written with tensor ops only.

    Args:
        maps (Tensor): all heatmaps; only ``maps.size(1)`` (number of keypoints) is read.
        maps_i (Tensor): heatmaps of this RoI. Used as ``maps_i[:, None]`` for interpolation
            and reshaped to ``(num_keypoints, -1)``, so presumably ``[num_keypoints, H, W]``
            (it is ``maps[i]`` at the visible call site).
        roi_map_width, roi_map_height (Tensor): size the heatmaps are upsampled to
            (the ceil'd RoI width/height at the visible call site).
        widths_i, heights_i (Tensor): RoI width/height used to rescale from the upsampled grid.
        offset_x_i, offset_y_i (Tensor): RoI top-left corner added to the decoded coordinates.

    Returns:
        xy_preds_i (Tensor[3, num_keypoints]): rows are x, y and a constant 1.
        end_scores_i (Tensor[num_keypoints]): upsampled heatmap value at each prediction.
    """
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    # Ratio between the true RoI size and the integer size of the upsampled map.
    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    # Width of the upsampled map, kept as a tensor (presumably so the exported graph does
    # not bake it in as a constant).
    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    # Flat argmax per keypoint, then unravel into column (x_int) and row (y_int).
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    # Discrete -> continuous coordinate (c = d + 0.5), rescaled to the RoI size.
    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    # NOTE(review): created without an explicit device, i.e. on the default device.
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    # With K = num_keypoints, the two index_selects below build a [K, K, K] tensor whose
    # element [k, a, b] is roi_map[k, y_int[a], x_int[b]]. The wanted diagonal entry
    # [k, k, k] sits at flat offset k*K*K + k*K + k = k * base.
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )

    return xy_preds_i, end_scores_i
@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    """Decode keypoints for every RoI by calling ``_onnx_heatmaps_to_keypoints`` in a loop.

    Used by ``heatmaps_to_keypoints`` when tracing. The decorator presumably scripts this
    function under tracing so that the loop over a data-dependent number of RoIs is
    preserved instead of being unrolled.

    Args:
        maps (Tensor): heatmaps of all RoIs; ``maps[i]`` belongs to RoI ``i``.
        rois (Tensor): RoI boxes; only ``rois.size(0)`` (number of RoIs) is read here.
        widths_ceil, heights_ceil (Tensor): per-RoI upsampled map size.
        widths, heights (Tensor): per-RoI box size.
        offset_x, offset_y (Tensor): per-RoI top-left corner.
        num_keypoints (Tensor): scalar keypoint count.

    Returns:
        xy_preds (Tensor[num_rois, 3, num_keypoints]) and end_scores (Tensor[num_rois, num_keypoints]).
    """
    # Start from empty float32 tensors and grow them by concatenation, one RoI at a time.
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores
  186. def heatmaps_to_keypoints(maps, rois):
  187. """Extract predicted keypoint locations from heatmaps.
  188. Args:
  189. maps (Tensor[K, N, H, W]): The predicted heatmaps, where K is the number of RoIs,
  190. N is the number of keypoints, and H, W are the heatmap spatial dimensions.
  191. rois (Tensor[K, 4]): The RoI boxes in ``(x1, y1, x2, y2)`` format.
  192. Returns:
  193. tuple:
  194. - **xy_preds** (Tensor[K, N, 3]): The predicted keypoint locations, where the last
  195. dimension contains ``(x, y, v)`` with x, y being coordinates and v being visibility (always 1).
  196. - **scores** (Tensor[K, N]): The heatmap scores at the predicted keypoint locations.
  197. """
  198. # This function converts a discrete image coordinate in a HEATMAP_SIZE x
  199. # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
  200. # consistency with keypoints_to_heatmap_labels by using the conversion from
  201. # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
  202. # continuous coordinate.
  203. offset_x = rois[:, 0]
  204. offset_y = rois[:, 1]
  205. widths = rois[:, 2] - rois[:, 0]
  206. heights = rois[:, 3] - rois[:, 1]
  207. widths = widths.clamp(min=1)
  208. heights = heights.clamp(min=1)
  209. widths_ceil = widths.ceil()
  210. heights_ceil = heights.ceil()
  211. num_keypoints = maps.shape[1]
  212. if torchvision._is_tracing():
  213. xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
  214. maps,
  215. rois,
  216. widths_ceil,
  217. heights_ceil,
  218. widths,
  219. heights,
  220. offset_x,
  221. offset_y,
  222. torch.scalar_tensor(num_keypoints, dtype=torch.int64),
  223. )
  224. return xy_preds.permute(0, 2, 1), end_scores
  225. xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
  226. end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device)
  227. for i in range(len(rois)):
  228. roi_map_width = int(widths_ceil[i].item())
  229. roi_map_height = int(heights_ceil[i].item())
  230. width_correction = widths[i] / roi_map_width
  231. height_correction = heights[i] / roi_map_height
  232. roi_map = F.interpolate(
  233. maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
  234. )[:, 0]
  235. # roi_map_probs = scores_to_probs(roi_map.copy())
  236. w = roi_map.shape[2]
  237. pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
  238. x_int = pos % w
  239. y_int = torch.div(pos - x_int, w, rounding_mode="floor")
  240. # assert (roi_map_probs[k, y_int, x_int] ==
  241. # roi_map_probs[k, :, :].max())
  242. x = (x_int.float() + 0.5) * width_correction
  243. y = (y_int.float() + 0.5) * height_correction
  244. xy_preds[i, 0, :] = x + offset_x[i]
  245. xy_preds[i, 1, :] = y + offset_y[i]
  246. xy_preds[i, 2, :] = 1
  247. end_scores[i, :] = roi_map[torch.arange(num_keypoints, device=roi_map.device), y_int, x_int]
  248. return xy_preds.permute(0, 2, 1), end_scores
  249. def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
  250. # type: (Tensor, list[Tensor], list[Tensor], list[Tensor]) -> Tensor
  251. N, K, H, W = keypoint_logits.shape
  252. if H != W:
  253. raise ValueError(
  254. f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
  255. )
  256. discretization_size = H
  257. heatmaps = []
  258. valid = []
  259. for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
  260. kp = gt_kp_in_image[midx]
  261. heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
  262. heatmaps.append(heatmaps_per_image.view(-1))
  263. valid.append(valid_per_image.view(-1))
  264. keypoint_targets = torch.cat(heatmaps, dim=0)
  265. valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
  266. valid = torch.where(valid)[0]
  267. # torch.mean (in binary_cross_entropy_with_logits) doesn't
  268. # accept empty tensors, so handle it sepaartely
  269. if keypoint_targets.numel() == 0 or len(valid) == 0:
  270. return keypoint_logits.sum() * 0
  271. keypoint_logits = keypoint_logits.view(N * K, H * W)
  272. keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
  273. return keypoint_loss
  274. def keypointrcnn_inference(x, boxes):
  275. # type: (Tensor, list[Tensor]) -> tuple[list[Tensor], list[Tensor]]
  276. kp_probs = []
  277. kp_scores = []
  278. boxes_per_image = [box.size(0) for box in boxes]
  279. x2 = x.split(boxes_per_image, dim=0)
  280. for xx, bb in zip(x2, boxes):
  281. kp_prob, scores = heatmaps_to_keypoints(xx, bb)
  282. kp_probs.append(kp_prob)
  283. kp_scores.append(scores)
  284. return kp_probs, kp_scores
  285. def _onnx_expand_boxes(boxes, scale):
  286. # type: (Tensor, float) -> Tensor
  287. w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  288. h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  289. x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  290. y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  291. w_half = w_half.to(dtype=torch.float32) * scale
  292. h_half = h_half.to(dtype=torch.float32) * scale
  293. boxes_exp0 = x_c - w_half
  294. boxes_exp1 = y_c - h_half
  295. boxes_exp2 = x_c + w_half
  296. boxes_exp3 = y_c + h_half
  297. boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
  298. return boxes_exp
  299. # the next two functions should be merged inside Masker
  300. # but are kept here for the moment while we need them
  301. # temporarily for paste_mask_in_image
  302. def expand_boxes(boxes, scale):
  303. # type: (Tensor, float) -> Tensor
  304. if torchvision._is_tracing():
  305. return _onnx_expand_boxes(boxes, scale)
  306. w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
  307. h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
  308. x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
  309. y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
  310. w_half *= scale
  311. h_half *= scale
  312. boxes_exp = torch.zeros_like(boxes)
  313. boxes_exp[:, 0] = x_c - w_half
  314. boxes_exp[:, 2] = x_c + w_half
  315. boxes_exp[:, 1] = y_c - h_half
  316. boxes_exp[:, 3] = y_c + h_half
  317. return boxes_exp
  318. @torch.jit.unused
  319. def expand_masks_tracing_scale(M, padding):
  320. # type: (int, int) -> float
  321. return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
  322. def expand_masks(mask, padding):
  323. # type: (Tensor, int) -> tuple[Tensor, float]
  324. M = mask.shape[-1]
  325. if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
  326. scale = expand_masks_tracing_scale(M, padding)
  327. else:
  328. scale = float(M + 2 * padding) / M
  329. padded_mask = F.pad(mask, (padding,) * 4)
  330. return padded_mask, scale
  331. def paste_mask_in_image(mask, box, im_h, im_w):
  332. # type: (Tensor, Tensor, int, int) -> Tensor
  333. TO_REMOVE = 1
  334. w = int(box[2] - box[0] + TO_REMOVE)
  335. h = int(box[3] - box[1] + TO_REMOVE)
  336. w = max(w, 1)
  337. h = max(h, 1)
  338. # Set shape to [batchxCxHxW]
  339. mask = mask.expand((1, 1, -1, -1))
  340. # Resize mask
  341. mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
  342. mask = mask[0][0]
  343. im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
  344. x_0 = max(box[0], 0)
  345. x_1 = min(box[2] + 1, im_w)
  346. y_0 = max(box[1], 0)
  347. y_1 = min(box[3] + 1, im_h)
  348. im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
  349. return im_mask
def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    """Tracing/ONNX-oriented counterpart of ``paste_mask_in_image`` for a single mask.

    Uses only tensor ops (no Python ``max``/``min`` on scalars), and replaces the
    in-place slice assignment with concatenation of zero blocks.

    Args:
        mask (Tensor): 2-D mask (``mask.size(0)``/``mask.size(1)`` are its height/width).
        box (Tensor): int64 ``(x1, y1, x2, y2)`` box, treated as inclusive of ``x2``/``y2``.
        im_h, im_w (Tensor): scalar int64 image size.

    Returns:
        Tensor: float32 pasted mask, cropped to ``[im_h, im_w]``.
    """
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    # Box size with inclusive coordinates (the "+ one" matches TO_REMOVE in the eager
    # version), clamped to at least 1 pixel.
    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

    # Set shape to [batchxCxHxW]
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    # Visible part of the box, clipped to the image.
    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
    x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX
    # NOTE(review): the zero blocks below are created on the default device.

    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask
@torch.jit._script_if_tracing
def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
    """Paste every mask with ``_onnx_paste_mask_in_image`` and stack the results.

    Used by ``paste_masks_in_image`` when tracing. The decorator presumably scripts this
    function under tracing so that the loop over a data-dependent number of masks survives.

    Args:
        masks (Tensor): masks indexed as ``masks[i][0]`` (presumably ``[N, 1, M, M]``).
        boxes (Tensor): int64 boxes, one per mask.
        im_h, im_w (Tensor): scalar int64 image size.

    Returns:
        Tensor: ``[N, im_h, im_w]`` pasted masks (float32, default device).
    """
    # Grow the result from an empty [0, im_h, im_w] tensor by concatenation.
    res_append = torch.zeros(0, im_h, im_w)
    for i in range(masks.size(0)):
        mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
        mask_res = mask_res.unsqueeze(0)
        res_append = torch.cat((res_append, mask_res))
    return res_append
  385. def paste_masks_in_image(masks, boxes, img_shape, padding=1):
  386. # type: (Tensor, Tensor, tuple[int, int], int) -> Tensor
  387. masks, scale = expand_masks(masks, padding=padding)
  388. boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
  389. im_h, im_w = img_shape
  390. if torchvision._is_tracing():
  391. return _onnx_paste_masks_in_image_loop(
  392. masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
  393. )[:, None]
  394. res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
  395. if len(res) > 0:
  396. ret = torch.stack(res, dim=0)[:, None]
  397. else:
  398. ret = masks.new_empty((0, 1, im_h, im_w))
  399. return ret
  400. class RoIHeads(nn.Module):
  401. __annotations__ = {
  402. "box_coder": det_utils.BoxCoder,
  403. "proposal_matcher": det_utils.Matcher,
  404. "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
  405. }
  406. def __init__(
  407. self,
  408. box_roi_pool,
  409. box_head,
  410. box_predictor,
  411. # Faster R-CNN training
  412. fg_iou_thresh,
  413. bg_iou_thresh,
  414. batch_size_per_image,
  415. positive_fraction,
  416. bbox_reg_weights,
  417. # Faster R-CNN inference
  418. score_thresh,
  419. nms_thresh,
  420. detections_per_img,
  421. # Mask
  422. mask_roi_pool=None,
  423. mask_head=None,
  424. mask_predictor=None,
  425. keypoint_roi_pool=None,
  426. keypoint_head=None,
  427. keypoint_predictor=None,
  428. ):
  429. super().__init__()
  430. self.box_similarity = box_ops.box_iou
  431. # assign ground-truth boxes for each proposal
  432. self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
  433. self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
  434. if bbox_reg_weights is None:
  435. bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
  436. self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
  437. self.box_roi_pool = box_roi_pool
  438. self.box_head = box_head
  439. self.box_predictor = box_predictor
  440. self.score_thresh = score_thresh
  441. self.nms_thresh = nms_thresh
  442. self.detections_per_img = detections_per_img
  443. self.mask_roi_pool = mask_roi_pool
  444. self.mask_head = mask_head
  445. self.mask_predictor = mask_predictor
  446. self.keypoint_roi_pool = keypoint_roi_pool
  447. self.keypoint_head = keypoint_head
  448. self.keypoint_predictor = keypoint_predictor
  449. def has_mask(self):
  450. if self.mask_roi_pool is None:
  451. return False
  452. if self.mask_head is None:
  453. return False
  454. if self.mask_predictor is None:
  455. return False
  456. return True
  457. def has_keypoint(self):
  458. if self.keypoint_roi_pool is None:
  459. return False
  460. if self.keypoint_head is None:
  461. return False
  462. if self.keypoint_predictor is None:
  463. return False
  464. return True
  465. def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
  466. # type: (list[Tensor], list[Tensor], list[Tensor]) -> tuple[list[Tensor], list[Tensor]]
  467. matched_idxs = []
  468. labels = []
  469. for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
  470. if gt_boxes_in_image.numel() == 0:
  471. # Background image
  472. device = proposals_in_image.device
  473. clamped_matched_idxs_in_image = torch.zeros(
  474. (proposals_in_image.shape[0],), dtype=torch.int64, device=device
  475. )
  476. labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
  477. else:
  478. # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
  479. match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
  480. matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
  481. clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
  482. labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
  483. labels_in_image = labels_in_image.to(dtype=torch.int64)
  484. # Label background (below the low threshold)
  485. bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
  486. labels_in_image[bg_inds] = 0
  487. # Label ignore proposals (between low and high thresholds)
  488. ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
  489. labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
  490. matched_idxs.append(clamped_matched_idxs_in_image)
  491. labels.append(labels_in_image)
  492. return matched_idxs, labels
  493. def subsample(self, labels):
  494. # type: (list[Tensor]) -> list[Tensor]
  495. sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
  496. sampled_inds = []
  497. for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
  498. img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
  499. sampled_inds.append(img_sampled_inds)
  500. return sampled_inds
  501. def add_gt_proposals(self, proposals, gt_boxes):
  502. # type: (list[Tensor], list[Tensor]) -> list[Tensor]
  503. proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
  504. return proposals
  505. def check_targets(self, targets):
  506. # type: (Optional[list[dict[str, Tensor]]]) -> None
  507. if targets is None:
  508. raise ValueError("targets should not be None")
  509. if not all(["boxes" in t for t in targets]):
  510. raise ValueError("Every element of targets should have a boxes key")
  511. if not all(["labels" in t for t in targets]):
  512. raise ValueError("Every element of targets should have a labels key")
  513. if self.has_mask():
  514. if not all(["masks" in t for t in targets]):
  515. raise ValueError("Every element of targets should have a masks key")
  516. def select_training_samples(
  517. self,
  518. proposals, # type: list[Tensor]
  519. targets, # type: Optional[list[dict[str, Tensor]]]
  520. ):
  521. # type: (...) -> tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]
  522. self.check_targets(targets)
  523. if targets is None:
  524. raise ValueError("targets should not be None")
  525. dtype = proposals[0].dtype
  526. device = proposals[0].device
  527. gt_boxes = [t["boxes"].to(dtype) for t in targets]
  528. gt_labels = [t["labels"] for t in targets]
  529. # append ground-truth bboxes to propos
  530. proposals = self.add_gt_proposals(proposals, gt_boxes)
  531. # get matching gt indices for each proposal
  532. matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
  533. # sample a fixed proportion of positive-negative proposals
  534. sampled_inds = self.subsample(labels)
  535. matched_gt_boxes = []
  536. num_images = len(proposals)
  537. for img_id in range(num_images):
  538. img_sampled_inds = sampled_inds[img_id]
  539. proposals[img_id] = proposals[img_id][img_sampled_inds]
  540. labels[img_id] = labels[img_id][img_sampled_inds]
  541. matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
  542. gt_boxes_in_image = gt_boxes[img_id]
  543. if gt_boxes_in_image.numel() == 0:
  544. gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
  545. matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
  546. regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
  547. return proposals, matched_idxs, labels, regression_targets
  548. def postprocess_detections(
  549. self,
  550. class_logits, # type: Tensor
  551. box_regression, # type: Tensor
  552. proposals, # type: list[Tensor]
  553. image_shapes, # type: list[tuple[int, int]]
  554. ):
  555. # type: (...) -> tuple[list[Tensor], list[Tensor], list[Tensor]]
  556. device = class_logits.device
  557. num_classes = class_logits.shape[-1]
  558. boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
  559. pred_boxes = self.box_coder.decode(box_regression, proposals)
  560. pred_scores = F.softmax(class_logits, -1)
  561. pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
  562. pred_scores_list = pred_scores.split(boxes_per_image, 0)
  563. all_boxes = []
  564. all_scores = []
  565. all_labels = []
  566. for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
  567. boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
  568. # create labels for each prediction
  569. labels = torch.arange(num_classes, device=device)
  570. labels = labels.view(1, -1).expand_as(scores)
  571. # remove predictions with the background label
  572. boxes = boxes[:, 1:]
  573. scores = scores[:, 1:]
  574. labels = labels[:, 1:]
  575. # batch everything, by making every class prediction be a separate instance
  576. boxes = boxes.reshape(-1, 4)
  577. scores = scores.reshape(-1)
  578. labels = labels.reshape(-1)
  579. # remove low scoring boxes
  580. inds = torch.where(scores > self.score_thresh)[0]
  581. boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
  582. # remove empty boxes
  583. keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
  584. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  585. # non-maximum suppression, independently done per class
  586. keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
  587. # keep only topk scoring predictions
  588. keep = keep[: self.detections_per_img]
  589. boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
  590. all_boxes.append(boxes)
  591. all_scores.append(scores)
  592. all_labels.append(labels)
  593. return all_boxes, all_scores, all_labels
  594. def forward(
  595. self,
  596. features: dict[str, torch.Tensor],
  597. proposals: list[torch.Tensor],
  598. image_shapes: list[tuple[int, int]],
  599. targets: Optional[list[dict[str, torch.Tensor]]] = None,
  600. ) -> tuple[list[dict[str, torch.Tensor]], dict[str, torch.Tensor]]:
  601. """
  602. Args:
  603. features (List[Tensor])
  604. proposals (List[Tensor[N, 4]])
  605. image_shapes (List[Tuple[H, W]])
  606. targets (List[Dict])
  607. """
  608. if targets is not None:
  609. for t in targets:
  610. # TODO: https://github.com/pytorch/pytorch/issues/26731
  611. floating_point_types = (torch.float, torch.double, torch.half)
  612. if t["boxes"].dtype not in floating_point_types:
  613. raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
  614. if not t["labels"].dtype == torch.int64:
  615. raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
  616. if self.has_keypoint():
  617. if not t["keypoints"].dtype == torch.float32:
  618. raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")
  619. if self.training:
  620. proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
  621. else:
  622. labels = None
  623. regression_targets = None
  624. matched_idxs = None
  625. box_features = self.box_roi_pool(features, proposals, image_shapes)
  626. box_features = self.box_head(box_features)
  627. class_logits, box_regression = self.box_predictor(box_features)
  628. result: list[dict[str, torch.Tensor]] = []
  629. losses = {}
  630. if self.training:
  631. if labels is None:
  632. raise ValueError("labels cannot be None")
  633. if regression_targets is None:
  634. raise ValueError("regression_targets cannot be None")
  635. loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
  636. losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
  637. else:
  638. boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
  639. num_images = len(boxes)
  640. for i in range(num_images):
  641. result.append(
  642. {
  643. "boxes": boxes[i],
  644. "labels": labels[i],
  645. "scores": scores[i],
  646. }
  647. )
  648. if self.has_mask():
  649. mask_proposals = [p["boxes"] for p in result]
  650. if self.training:
  651. if matched_idxs is None:
  652. raise ValueError("if in training, matched_idxs should not be None")
  653. # during training, only focus on positive boxes
  654. num_images = len(proposals)
  655. mask_proposals = []
  656. pos_matched_idxs = []
  657. for img_id in range(num_images):
  658. pos = torch.where(labels[img_id] > 0)[0]
  659. mask_proposals.append(proposals[img_id][pos])
  660. pos_matched_idxs.append(matched_idxs[img_id][pos])
  661. else:
  662. pos_matched_idxs = None
  663. if self.mask_roi_pool is not None:
  664. mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
  665. mask_features = self.mask_head(mask_features)
  666. mask_logits = self.mask_predictor(mask_features)
  667. else:
  668. raise Exception("Expected mask_roi_pool to be not None")
  669. loss_mask = {}
  670. if self.training:
  671. if targets is None or pos_matched_idxs is None or mask_logits is None:
  672. raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
  673. gt_masks = [t["masks"] for t in targets]
  674. gt_labels = [t["labels"] for t in targets]
  675. rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
  676. loss_mask = {"loss_mask": rcnn_loss_mask}
  677. else:
  678. labels = [r["labels"] for r in result]
  679. masks_probs = maskrcnn_inference(mask_logits, labels)
  680. for mask_prob, r in zip(masks_probs, result):
  681. r["masks"] = mask_prob
  682. losses.update(loss_mask)
  683. # keep none checks in if conditional so torchscript will conditionally
  684. # compile each branch
  685. if (
  686. self.keypoint_roi_pool is not None
  687. and self.keypoint_head is not None
  688. and self.keypoint_predictor is not None
  689. ):
  690. keypoint_proposals = [p["boxes"] for p in result]
  691. if self.training:
  692. # during training, only focus on positive boxes
  693. num_images = len(proposals)
  694. keypoint_proposals = []
  695. pos_matched_idxs = []
  696. if matched_idxs is None:
  697. raise ValueError("if in trainning, matched_idxs should not be None")
  698. for img_id in range(num_images):
  699. pos = torch.where(labels[img_id] > 0)[0]
  700. keypoint_proposals.append(proposals[img_id][pos])
  701. pos_matched_idxs.append(matched_idxs[img_id][pos])
  702. else:
  703. pos_matched_idxs = None
  704. keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
  705. keypoint_features = self.keypoint_head(keypoint_features)
  706. keypoint_logits = self.keypoint_predictor(keypoint_features)
  707. loss_keypoint = {}
  708. if self.training:
  709. if targets is None or pos_matched_idxs is None:
  710. raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
  711. gt_keypoints = [t["keypoints"] for t in targets]
  712. rcnn_loss_keypoint = keypointrcnn_loss(
  713. keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
  714. )
  715. loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
  716. else:
  717. if keypoint_logits is None or keypoint_proposals is None:
  718. raise ValueError(
  719. "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
  720. )
  721. keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
  722. for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
  723. r["keypoints"] = keypoint_prob
  724. r["keypoints_scores"] = kps
  725. losses.update(loss_keypoint)
  726. return result, losses