boxes.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. import torch
  2. import torchvision
  3. from torch import Tensor
  4. from torchvision.extension import _assert_has_ops
  5. from ..utils import _log_api_usage_once
  6. from ._box_convert import (
  7. _box_cxcywh_to_xyxy,
  8. _box_cxcywhr_to_xywhr,
  9. _box_xywh_to_xyxy,
  10. _box_xywhr_to_cxcywhr,
  11. _box_xywhr_to_xyxyxyxy,
  12. _box_xyxy_to_cxcywh,
  13. _box_xyxy_to_xywh,
  14. _box_xyxyxyxy_to_xywhr,
  15. )
  16. from ._utils import _upcast
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than ``iou_threshold`` with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Args:
        boxes (Tensor[N, 4]): boxes to perform NMS on. They
            are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): scores for each one of the boxes
        iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    # API-usage logging only happens in eager mode; it is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(nms)
    # NOTE(review): presumably fails early when the compiled torchvision ops are
    # unavailable -- confirm against torchvision.extension._assert_has_ops.
    _assert_has_ops()
    # All of the actual work is delegated to the registered ``torchvision::nms`` operator.
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
  42. def batched_nms(
  43. boxes: Tensor,
  44. scores: Tensor,
  45. idxs: Tensor,
  46. iou_threshold: float,
  47. ) -> Tensor:
  48. """
  49. Performs non-maximum suppression in a batched fashion.
  50. Each index value correspond to a category, and NMS
  51. will not be applied between elements of different categories.
  52. Args:
  53. boxes (Tensor[N, 4]): boxes where NMS will be performed. They
  54. are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
  55. ``0 <= y1 < y2``.
  56. scores (Tensor[N]): scores for each one of the boxes
  57. idxs (Tensor[N]): indices of the categories for each one of the boxes.
  58. iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold
  59. Returns:
  60. Tensor: int64 tensor with the indices of the elements that have been kept by NMS, sorted
  61. in decreasing order of scores
  62. """
  63. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  64. _log_api_usage_once(batched_nms)
  65. # Benchmarks that drove the following thresholds are at
  66. # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
  67. # and https://github.com/pytorch/vision/pull/8925
  68. if boxes.numel() > (4000 if boxes.device.type == "cpu" else 100_000) and not torchvision._is_tracing():
  69. return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
  70. else:
  71. return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
  72. @torch.jit._script_if_tracing
  73. def _batched_nms_coordinate_trick(
  74. boxes: Tensor,
  75. scores: Tensor,
  76. idxs: Tensor,
  77. iou_threshold: float,
  78. ) -> Tensor:
  79. # strategy: in order to perform NMS independently per class,
  80. # we add an offset to all the boxes. The offset is dependent
  81. # only on the class idx, and is large enough so that boxes
  82. # from different classes do not overlap
  83. if boxes.numel() == 0:
  84. return torch.empty((0,), dtype=torch.int64, device=boxes.device)
  85. max_coordinate = boxes.max()
  86. offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
  87. boxes_for_nms = boxes + offsets[:, None]
  88. keep = nms(boxes_for_nms, scores, iou_threshold)
  89. return keep
  90. @torch.jit._script_if_tracing
  91. def _batched_nms_vanilla(
  92. boxes: Tensor,
  93. scores: Tensor,
  94. idxs: Tensor,
  95. iou_threshold: float,
  96. ) -> Tensor:
  97. # Based on Detectron2 implementation, just manually call nms() on each class independently
  98. keep_mask = torch.zeros_like(scores, dtype=torch.bool)
  99. for class_id in torch.unique(idxs):
  100. curr_indices = torch.where(idxs == class_id)[0]
  101. curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
  102. keep_mask[curr_indices[curr_keep_indices]] = True
  103. keep_indices = torch.where(keep_mask)[0]
  104. return keep_indices[scores[keep_indices].sort(descending=True)[1]]
  105. def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
  106. """
  107. Remove every box from ``boxes`` which contains at least one side length
  108. that is smaller than ``min_size``.
  109. .. note::
  110. For sanitizing a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
  111. the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.
  112. Args:
  113. boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
  114. with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  115. min_size (float): minimum size
  116. Returns:
  117. Tensor[K]: indices of the boxes that have both sides
  118. larger than ``min_size``
  119. """
  120. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  121. _log_api_usage_once(remove_small_boxes)
  122. ws, hs = boxes[..., 2] - boxes[..., 0], boxes[..., 3] - boxes[..., 1]
  123. keep = (ws >= min_size) & (hs >= min_size)
  124. keep = torch.where(keep)[0]
  125. return keep
  126. def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
  127. """
  128. Clip boxes so that they lie inside an image of size ``size``.
  129. .. note::
  130. For clipping a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
  131. the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.
  132. Args:
  133. boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
  134. with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  135. size (Tuple[height, width]): size of the image
  136. Returns:
  137. Tensor[..., 4]: clipped boxes
  138. """
  139. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  140. _log_api_usage_once(clip_boxes_to_image)
  141. dim = boxes.dim()
  142. boxes_x = boxes[..., 0::2]
  143. boxes_y = boxes[..., 1::2]
  144. height, width = size
  145. if torchvision._is_tracing():
  146. boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
  147. boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
  148. boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
  149. boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
  150. else:
  151. boxes_x = boxes_x.clamp(min=0, max=width)
  152. boxes_y = boxes_y.clamp(min=0, max=height)
  153. clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
  154. return clipped_boxes.reshape(boxes.shape)
  155. def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
  156. """
  157. Converts :class:`torch.Tensor` boxes from a given ``in_fmt`` to ``out_fmt``.
  158. .. note::
  159. For converting a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.BoundingBoxes` object
  160. between different formats,
  161. consider using :func:`~torchvision.transforms.v2.functional.convert_bounding_box_format` instead.
  162. Or see the corresponding transform :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat`.
  163. Supported ``in_fmt`` and ``out_fmt`` strings are:
  164. ``'xyxy'``: boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
  165. This is the format that torchvision utilities expect.
  166. ``'xywh'``: boxes are represented via corner, width and height, x1, y2 being top left, w, h being width and height.
  167. ``'cxcywh'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h
  168. being width and height.
  169. ``'xywhr'``: boxes are represented via corner, width and height, x1, y2 being top left, w, h being width and height.
  170. r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
  171. ``'cxcywhr'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h
  172. being width and height.
  173. r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
  174. ``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 top right,
  175. x3, y3 bottom right, and x4, y4 bottom left.
  176. Args:
  177. boxes (Tensor[..., K]): boxes which will be converted. K is the number of coordinates (4 for unrotated bounding boxes, 5 or 8 for rotated bounding boxes). Supports any number of leading batch dimensions.
  178. in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'xywhr', 'cxcywhr', 'xyxyxyxy'].
  179. out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'xywhr', 'cxcywhr', 'xyxyxyxy']
  180. Returns:
  181. Tensor[..., K]: Boxes into converted format.
  182. """
  183. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  184. _log_api_usage_once(box_convert)
  185. allowed_fmts = (
  186. "xyxy",
  187. "xywh",
  188. "cxcywh",
  189. "xywhr",
  190. "cxcywhr",
  191. "xyxyxyxy",
  192. )
  193. if in_fmt not in allowed_fmts or out_fmt not in allowed_fmts:
  194. raise ValueError(f"Unsupported Bounding Box Conversions for given in_fmt {in_fmt} and out_fmt {out_fmt}")
  195. if in_fmt == out_fmt:
  196. return boxes.clone()
  197. e = (in_fmt, out_fmt)
  198. if e == ("xywh", "xyxy"):
  199. boxes = _box_xywh_to_xyxy(boxes)
  200. elif e == ("cxcywh", "xyxy"):
  201. boxes = _box_cxcywh_to_xyxy(boxes)
  202. elif e == ("xyxy", "xywh"):
  203. boxes = _box_xyxy_to_xywh(boxes)
  204. elif e == ("xyxy", "cxcywh"):
  205. boxes = _box_xyxy_to_cxcywh(boxes)
  206. elif e == ("xywh", "cxcywh"):
  207. boxes = _box_xywh_to_xyxy(boxes)
  208. boxes = _box_xyxy_to_cxcywh(boxes)
  209. elif e == ("cxcywh", "xywh"):
  210. boxes = _box_cxcywh_to_xyxy(boxes)
  211. boxes = _box_xyxy_to_xywh(boxes)
  212. elif e == ("cxcywhr", "xywhr"):
  213. boxes = _box_cxcywhr_to_xywhr(boxes)
  214. elif e == ("xywhr", "cxcywhr"):
  215. boxes = _box_xywhr_to_cxcywhr(boxes)
  216. elif e == ("cxcywhr", "xyxyxyxy"):
  217. boxes = _box_cxcywhr_to_xywhr(boxes).to(boxes.dtype)
  218. boxes = _box_xywhr_to_xyxyxyxy(boxes)
  219. elif e == ("xyxyxyxy", "cxcywhr"):
  220. boxes = _box_xyxyxyxy_to_xywhr(boxes).to(boxes.dtype)
  221. boxes = _box_xywhr_to_cxcywhr(boxes)
  222. elif e == ("xywhr", "xyxyxyxy"):
  223. boxes = _box_xywhr_to_xyxyxyxy(boxes)
  224. elif e == ("xyxyxyxy", "xywhr"):
  225. boxes = _box_xyxyxyxy_to_xywhr(boxes)
  226. else:
  227. raise NotImplementedError(f"Unsupported Bounding Box Conversions for given in_fmt {e[0]} and out_fmt {e[1]}")
  228. return boxes
  229. def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
  230. """
  231. Computes the area of a set of bounding boxes from a given format.
  232. Args:
  233. boxes (Tensor[..., 4]): boxes for which the area will be computed.
  234. fmt (str): Format of the input boxes.
  235. Default is "xyxy" to preserve backward compatibility.
  236. Supported formats are "xyxy", "xywh", and "cxcywh".
  237. Returns:
  238. Tensor[N]: Tensor containing the area for each box.
  239. """
  240. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  241. _log_api_usage_once(box_area)
  242. allowed_fmts = (
  243. "xyxy",
  244. "xywh",
  245. "cxcywh",
  246. )
  247. if fmt not in allowed_fmts:
  248. raise ValueError(f"Unsupported Bounding Box area for given format {fmt}")
  249. boxes = _upcast(boxes)
  250. if fmt == "xyxy":
  251. area = (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
  252. else:
  253. # For formats with width and height, area = width * height
  254. # Supported: cxcywh, xywh
  255. area = boxes[..., 2] * boxes[..., 3]
  256. return area
  257. # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
  258. # with slight modifications
  259. def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]:
  260. area1 = box_area(boxes1, fmt=fmt)
  261. area2 = box_area(boxes2, fmt=fmt)
  262. allowed_fmts = (
  263. "xyxy",
  264. "xywh",
  265. "cxcywh",
  266. )
  267. if fmt not in allowed_fmts:
  268. raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.")
  269. if fmt == "xyxy":
  270. lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
  271. rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
  272. elif fmt == "xywh":
  273. lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
  274. rb = torch.min(
  275. boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:]
  276. ) # [...,N,M,2]
  277. else: # fmt == "cxcywh":
  278. lt = torch.max(
  279. boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2
  280. ) # [N,M,2]
  281. rb = torch.min(
  282. boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2
  283. ) # [N,M,2]
  284. wh = _upcast(rb - lt).clamp(min=0) # [N,M,2]
  285. inter = wh[..., 0] * wh[..., 1] # [N,M]
  286. union = area1[..., None] + area2[..., None, :] - inter
  287. return inter, union
  288. def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
  289. """
  290. Return intersection-over-union (Jaccard index) between two sets of boxes from a given format.
  291. Args:
  292. boxes1 (Tensor[..., N, 4]): first set of boxes
  293. boxes2 (Tensor[..., M, 4]): second set of boxes
  294. fmt (str): Format of the input boxes.
  295. Default is "xyxy" to preserve backward compatibility.
  296. Supported formats are "xyxy", "xywh", and "cxcywh".
  297. Returns:
  298. Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
  299. in boxes1 and boxes2
  300. """
  301. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  302. _log_api_usage_once(box_iou)
  303. allowed_fmts = (
  304. "xyxy",
  305. "xywh",
  306. "cxcywh",
  307. )
  308. if fmt not in allowed_fmts:
  309. raise ValueError(f"Unsupported Box IoU Calculation for given format {fmt}.")
  310. inter, union = _box_inter_union(boxes1, boxes2, fmt=fmt)
  311. iou = inter / union
  312. return iou
  313. # Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
  314. def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
  315. """
  316. Return generalized intersection-over-union (Jaccard index) between two sets of boxes.
  317. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
  318. ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  319. Args:
  320. boxes1 (Tensor[..., N, 4]): first set of boxes
  321. boxes2 (Tensor[..., M, 4]): second set of boxes
  322. Returns:
  323. Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values
  324. for every element in boxes1 and boxes2
  325. """
  326. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  327. _log_api_usage_once(generalized_box_iou)
  328. inter, union = _box_inter_union(boxes1, boxes2)
  329. iou = inter / union
  330. lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
  331. rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
  332. whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
  333. areai = whi[..., 0] * whi[..., 1]
  334. return iou - (areai - union) / areai
  335. def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
  336. """
  337. Return complete intersection-over-union (Jaccard index) between two sets of boxes.
  338. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
  339. ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  340. Args:
  341. boxes1 (Tensor[..., N, 4]): first set of boxes
  342. boxes2 (Tensor[..., M, 4]): second set of boxes
  343. eps (float, optional): small number to prevent division by zero. Default: 1e-7
  344. Returns:
  345. Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values
  346. for every element in boxes1 and boxes2
  347. """
  348. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  349. _log_api_usage_once(complete_box_iou)
  350. boxes1 = _upcast(boxes1)
  351. boxes2 = _upcast(boxes2)
  352. diou, iou = _box_diou_iou(boxes1, boxes2, eps)
  353. w_pred = boxes1[..., None, 2] - boxes1[..., None, 0]
  354. h_pred = boxes1[..., None, 3] - boxes1[..., None, 1]
  355. w_gt = boxes2[..., None, :, 2] - boxes2[..., None, :, 0]
  356. h_gt = boxes2[..., None, :, 3] - boxes2[..., None, :, 1]
  357. v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
  358. with torch.no_grad():
  359. alpha = v / (1 - iou + v + eps)
  360. return diou - alpha * v
  361. def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
  362. """
  363. Return distance intersection-over-union (Jaccard index) between two sets of boxes.
  364. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
  365. ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
  366. Args:
  367. boxes1 (Tensor[..., N, 4]): first set of boxes
  368. boxes2 (Tensor[..., M, 4]): second set of boxes
  369. eps (float, optional): small number to prevent division by zero. Default: 1e-7
  370. Returns:
  371. Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values
  372. for every element in boxes1 and boxes2
  373. """
  374. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  375. _log_api_usage_once(distance_box_iou)
  376. boxes1 = _upcast(boxes1)
  377. boxes2 = _upcast(boxes2)
  378. diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
  379. return diou
  380. def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> tuple[Tensor, Tensor]:
  381. iou = box_iou(boxes1, boxes2)
  382. lti = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
  383. rbi = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
  384. whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
  385. diagonal_distance_squared = (whi[..., 0] ** 2) + (whi[..., 1] ** 2) + eps
  386. # centers of boxes
  387. x_p = (boxes1[..., 0] + boxes1[..., 2]) / 2
  388. y_p = (boxes1[..., 1] + boxes1[..., 3]) / 2
  389. x_g = (boxes2[..., 0] + boxes2[..., 2]) / 2
  390. y_g = (boxes2[..., 1] + boxes2[..., 3]) / 2
  391. # The distance between boxes' centers squared.
  392. centers_distance_squared = (_upcast(x_p[..., None] - x_g[..., None, :]) ** 2) + (
  393. _upcast(y_p[..., None] - y_g[..., None, :]) ** 2
  394. )
  395. # The distance IoU is the IoU penalized by a normalized
  396. # distance between boxes' centers squared.
  397. return iou - (centers_distance_squared / diagonal_distance_squared), iou
  398. def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
  399. """
  400. Compute the bounding boxes around the provided masks.
  401. Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
  402. ``0 <= x1 <= x2`` and ``0 <= y1 <= y2``.
  403. .. note::
  404. Empty masks (all zeros) will return bounding boxes ``[0, 0, 0, 0]``.
  405. .. warning::
  406. In most cases the output will guarantee ``x1 < x2`` and ``y1 < y2``. But
  407. if the input is degenerate, e.g. if a mask is a single row or a single
  408. column, then the output may have x1 = x2 or y1 = y2.
  409. Args:
  410. masks (Tensor[N, H, W]): masks to transform where N is the number of masks
  411. and (H, W) are the spatial dimensions.
  412. Returns:
  413. Tensor[N, 4]: bounding boxes
  414. """
  415. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  416. _log_api_usage_once(masks_to_boxes)
  417. if masks.numel() == 0:
  418. return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
  419. n, h, w = masks.shape
  420. masks_bool = masks.bool()
  421. non_zero_rows = torch.any(masks_bool, dim=2)
  422. non_zero_cols = torch.any(masks_bool, dim=1)
  423. empty_masks = ~torch.any(non_zero_rows, dim=1)
  424. non_zero_rows_f = non_zero_rows.float()
  425. non_zero_cols_f = non_zero_cols.float()
  426. y1 = non_zero_rows_f.argmax(dim=1)
  427. x1 = non_zero_cols_f.argmax(dim=1)
  428. y2 = (h - 1) - non_zero_rows_f.flip(dims=[1]).argmax(dim=1)
  429. x2 = (w - 1) - non_zero_cols_f.flip(dims=[1]).argmax(dim=1)
  430. bounding_boxes = torch.stack([x1, y1, x2, y2], dim=1).float()
  431. bounding_boxes[empty_masks] = 0
  432. return bounding_boxes