bbox.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import warnings
  19. from typing import Optional
  20. import torch
  21. from kornia.core import arange, stack, where
  22. from .linalg import transform_points
# Names exported by ``from <this module> import *``; every entry is a function defined below.
__all__ = [
    "bbox_generator",
    "bbox_generator3d",
    "bbox_to_mask",
    "bbox_to_mask3d",
    "infer_bbox_shape",
    "infer_bbox_shape3d",
    "nms",
    "transform_bbox",
    "validate_bbox",
    "validate_bbox3d",
]
  35. def validate_bbox(boxes: torch.Tensor) -> bool:
  36. """Validate if a 2D bounding box usable or not. This function checks if the boxes are rectangular or not.
  37. Args:
  38. boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
  39. of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right,
  40. bottom-left. The coordinates must be in the x, y order.
  41. """
  42. if not (len(boxes.shape) in [3, 4] and boxes.shape[-2:] == torch.Size([4, 2])):
  43. return False
  44. if len(boxes.shape) == 4:
  45. boxes = boxes.view(-1, 4, 2)
  46. x_tl, y_tl = boxes[..., 0, 0], boxes[..., 0, 1]
  47. x_tr, y_tr = boxes[..., 1, 0], boxes[..., 1, 1]
  48. x_br, y_br = boxes[..., 2, 0], boxes[..., 2, 1]
  49. x_bl, y_bl = boxes[..., 3, 0], boxes[..., 3, 1]
  50. width_t, width_b = x_tr - x_tl + 1, x_br - x_bl + 1
  51. height_t, height_b = y_tr - y_tl + 1, y_br - y_bl + 1
  52. # Replace torch.allclose with exportable operations
  53. width_diff = torch.abs(width_t - width_b)
  54. height_diff = torch.abs(height_t - height_b)
  55. # Check if differences are within tolerance (1e-4)
  56. if torch.any(width_diff > 1e-4):
  57. return False
  58. if torch.any(height_diff > 1e-4):
  59. return False
  60. return True
  61. def validate_bbox3d(boxes: torch.Tensor) -> bool:
  62. """Validate if a 3D bounding box usable or not. This function checks if the boxes are cube or not.
  63. Args:
  64. boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
  65. of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
  66. front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
  67. The coordinates must be in the x, y, z order.
  68. """
  69. if not (len(boxes.shape) in [3, 4] and boxes.shape[-2:] == torch.Size([8, 3])):
  70. raise AssertionError(f"Box shape must be (B, 8, 3) or (B, N, 8, 3). Got {boxes.shape}.")
  71. if len(boxes.shape) == 4:
  72. boxes = boxes.view(-1, 8, 3)
  73. left = torch.index_select(boxes, 1, torch.tensor([1, 2, 5, 6], device=boxes.device, dtype=torch.long))[:, :, 0]
  74. right = torch.index_select(boxes, 1, torch.tensor([0, 3, 4, 7], device=boxes.device, dtype=torch.long))[:, :, 0]
  75. widths = left - right + 1
  76. if not torch.allclose(widths.permute(1, 0), widths[:, 0]):
  77. raise AssertionError(f"Boxes must have be cube, while get different widths {widths}.")
  78. bot = torch.index_select(boxes, 1, torch.tensor([2, 3, 6, 7], device=boxes.device, dtype=torch.long))[:, :, 1]
  79. upper = torch.index_select(boxes, 1, torch.tensor([0, 1, 4, 5], device=boxes.device, dtype=torch.long))[:, :, 1]
  80. heights = bot - upper + 1
  81. if not torch.allclose(heights.permute(1, 0), heights[:, 0]):
  82. raise AssertionError(f"Boxes must have be cube, while get different heights {heights}.")
  83. depths = boxes[:, 4:, 2] - boxes[:, :4, 2] + 1
  84. if not torch.allclose(depths.permute(1, 0), depths[:, 0]):
  85. raise AssertionError(f"Boxes must have be cube, while get different depths {depths}.")
  86. return True
  87. def infer_bbox_shape(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  88. r"""Auto-infer the output sizes for the given 2D bounding boxes.
  89. Args:
  90. boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
  91. of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right,
  92. bottom-left. The coordinates must be in the x, y order.
  93. Returns:
  94. - Bounding box heights, shape of :math:`(B,)`.
  95. - Boundingbox widths, shape of :math:`(B,)`.
  96. Example:
  97. >>> boxes = torch.tensor([[
  98. ... [1., 1.],
  99. ... [2., 1.],
  100. ... [2., 2.],
  101. ... [1., 2.],
  102. ... ], [
  103. ... [1., 1.],
  104. ... [3., 1.],
  105. ... [3., 2.],
  106. ... [1., 2.],
  107. ... ]]) # 2x4x2
  108. >>> infer_bbox_shape(boxes)
  109. (tensor([2., 2.]), tensor([2., 3.]))
  110. """
  111. validate_bbox(boxes)
  112. width: torch.Tensor = boxes[:, 1, 0] - boxes[:, 0, 0] + 1
  113. height: torch.Tensor = boxes[:, 2, 1] - boxes[:, 0, 1] + 1
  114. return height, width
  115. def infer_bbox_shape3d(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  116. r"""Auto-infer the output sizes for the given 3D bounding boxes.
  117. Args:
  118. boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
  119. of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
  120. front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
  121. The coordinates must be in the x, y, z order.
  122. Returns:
  123. - Bounding box depths, shape of :math:`(B,)`.
  124. - Bounding box heights, shape of :math:`(B,)`.
  125. - Bounding box widths, shape of :math:`(B,)`.
  126. Example:
  127. >>> boxes = torch.tensor([[[ 0, 1, 2],
  128. ... [10, 1, 2],
  129. ... [10, 21, 2],
  130. ... [ 0, 21, 2],
  131. ... [ 0, 1, 32],
  132. ... [10, 1, 32],
  133. ... [10, 21, 32],
  134. ... [ 0, 21, 32]],
  135. ... [[ 3, 4, 5],
  136. ... [43, 4, 5],
  137. ... [43, 54, 5],
  138. ... [ 3, 54, 5],
  139. ... [ 3, 4, 65],
  140. ... [43, 4, 65],
  141. ... [43, 54, 65],
  142. ... [ 3, 54, 65]]]) # 2x8x3
  143. >>> infer_bbox_shape3d(boxes)
  144. (tensor([31, 61]), tensor([21, 51]), tensor([11, 41]))
  145. """
  146. validate_bbox3d(boxes)
  147. left = torch.index_select(boxes, 1, torch.tensor([1, 2, 5, 6], device=boxes.device, dtype=torch.long))[:, :, 0]
  148. right = torch.index_select(boxes, 1, torch.tensor([0, 3, 4, 7], device=boxes.device, dtype=torch.long))[:, :, 0]
  149. widths = (left - right + 1)[:, 0]
  150. bot = torch.index_select(boxes, 1, torch.tensor([2, 3, 6, 7], device=boxes.device, dtype=torch.long))[:, :, 1]
  151. upper = torch.index_select(boxes, 1, torch.tensor([0, 1, 4, 5], device=boxes.device, dtype=torch.long))[:, :, 1]
  152. heights = (bot - upper + 1)[:, 0]
  153. depths = (boxes[:, 4:, 2] - boxes[:, :4, 2] + 1)[:, 0]
  154. return depths, heights, widths
  155. def bbox_to_mask(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
  156. """Convert 2D bounding boxes to masks. Covered area is 1. and the remaining is 0.
  157. Args:
  158. boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
  159. of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right
  160. and bottom-left. The coordinates must be in the x, y order.
  161. width: width of the masked image.
  162. height: height of the masked image.
  163. Returns:
  164. the output mask tensor.
  165. Note:
  166. It is currently non-differentiable.
  167. Examples:
  168. >>> boxes = torch.tensor([[
  169. ... [1., 1.],
  170. ... [3., 1.],
  171. ... [3., 2.],
  172. ... [1., 2.],
  173. ... ]]) # 1x4x2
  174. >>> bbox_to_mask(boxes, 5, 5)
  175. tensor([[[0., 0., 0., 0., 0.],
  176. [0., 1., 1., 1., 0.],
  177. [0., 1., 1., 1., 0.],
  178. [0., 0., 0., 0., 0.],
  179. [0., 0., 0., 0., 0.]]])
  180. """
  181. validate_bbox(boxes)
  182. # zero padding the surroundings
  183. yy = torch.arange(height, device=boxes.device, dtype=boxes.dtype).view(height, 1)
  184. xx = torch.arange(width, device=boxes.device, dtype=boxes.dtype).view(1, width)
  185. x_min = boxes[:, 0, 0].view(-1, 1, 1)
  186. y_min = boxes[:, 0, 1].view(-1, 1, 1)
  187. x_max = boxes[:, 2, 0].view(-1, 1, 1)
  188. y_max = boxes[:, 2, 1].view(-1, 1, 1)
  189. mask = (xx >= x_min) & (xx <= x_max) & (yy >= y_min) & (yy <= y_max)
  190. return mask.to(boxes.dtype)
  191. def bbox_to_mask3d(boxes: torch.Tensor, size: tuple[int, int, int]) -> torch.Tensor:
  192. """Convert 3D bounding boxes to masks. Covered area is 1. and the remaining is 0.
  193. Args:
  194. boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
  195. of Bx8x3, where each box is defined in the following ``clockwise`` order: front-top-left, front-top-right,
  196. front-bottom-right, front-bottom-left, back-top-left, back-top-right, back-bottom-right, back-bottom-left.
  197. The coordinates must be in the x, y, z order.
  198. size: depth, height and width of the masked image.
  199. Returns:
  200. the output mask tensor.
  201. Examples:
  202. >>> boxes = torch.tensor([[
  203. ... [1., 1., 1.],
  204. ... [2., 1., 1.],
  205. ... [2., 2., 1.],
  206. ... [1., 2., 1.],
  207. ... [1., 1., 2.],
  208. ... [2., 1., 2.],
  209. ... [2., 2., 2.],
  210. ... [1., 2., 2.],
  211. ... ]]) # 1x8x3
  212. >>> bbox_to_mask3d(boxes, (4, 5, 5))
  213. tensor([[[[[0., 0., 0., 0., 0.],
  214. [0., 0., 0., 0., 0.],
  215. [0., 0., 0., 0., 0.],
  216. [0., 0., 0., 0., 0.],
  217. [0., 0., 0., 0., 0.]],
  218. <BLANKLINE>
  219. [[0., 0., 0., 0., 0.],
  220. [0., 1., 1., 0., 0.],
  221. [0., 1., 1., 0., 0.],
  222. [0., 0., 0., 0., 0.],
  223. [0., 0., 0., 0., 0.]],
  224. <BLANKLINE>
  225. [[0., 0., 0., 0., 0.],
  226. [0., 1., 1., 0., 0.],
  227. [0., 1., 1., 0., 0.],
  228. [0., 0., 0., 0., 0.],
  229. [0., 0., 0., 0., 0.]],
  230. <BLANKLINE>
  231. [[0., 0., 0., 0., 0.],
  232. [0., 0., 0., 0., 0.],
  233. [0., 0., 0., 0., 0.],
  234. [0., 0., 0., 0., 0.],
  235. [0., 0., 0., 0., 0.]]]]])
  236. """
  237. validate_bbox3d(boxes)
  238. D0, D1, D2 = size # get depth, height, width
  239. z_min = boxes[:, 0, 2].long()
  240. z_max = boxes[:, 4, 2].long()
  241. y_min = boxes[:, 1, 1].long()
  242. y_max = boxes[:, 2, 1].long()
  243. x_min = boxes[:, 0, 0].long()
  244. x_max = boxes[:, 1, 0].long()
  245. z = arange(D0, device=boxes.device, dtype=torch.long)
  246. y = arange(D1, device=boxes.device, dtype=torch.long)
  247. x = arange(D2, device=boxes.device, dtype=torch.long)
  248. # Compute mask as union of planes in one step
  249. m = (
  250. ((z[None, :] >= z_min[:, None]) & (z[None, :] <= z_max[:, None]))[:, None, :, None, None]
  251. | ((y[None, :] >= y_min[:, None]) & (y[None, :] <= y_max[:, None]))[:, None, None, :, None]
  252. | ((x[None, :] >= x_min[:, None]) & (x[None, :] <= x_max[:, None]))[:, None, None, None, :]
  253. ).float() # Shape: (N, 1, D0, D1, D2)
  254. # Compute conditions
  255. cond1 = m.all(dim=3, keepdim=True).all(dim=2, keepdim=True)
  256. cond2 = m.all(dim=4, keepdim=True).all(dim=2, keepdim=True)
  257. cond3 = m.all(dim=3, keepdim=True).all(dim=4, keepdim=True)
  258. m_out = cond1 * cond2 * cond3 # Broadcasting to (N, 1, D0, D1, D2)
  259. return m_out.float()
  260. def bbox_generator(
  261. x_start: torch.Tensor, y_start: torch.Tensor, width: torch.Tensor, height: torch.Tensor
  262. ) -> torch.Tensor:
  263. """Generate 2D bounding boxes according to the provided start coords, width and height.
  264. Args:
  265. x_start: a tensor containing the x coordinates of the bounding boxes to be extracted. Shape must be a scalar
  266. tensor or :math:`(B,)`.
  267. y_start: a tensor containing the y coordinates of the bounding boxes to be extracted. Shape must be a scalar
  268. tensor or :math:`(B,)`.
  269. width: widths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
  270. height: heights of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
  271. Returns:
  272. the bounding box tensor.
  273. Examples:
  274. >>> x_start = torch.tensor([0, 1])
  275. >>> y_start = torch.tensor([1, 0])
  276. >>> width = torch.tensor([5, 3])
  277. >>> height = torch.tensor([7, 4])
  278. >>> bbox_generator(x_start, y_start, width, height)
  279. tensor([[[0, 1],
  280. [4, 1],
  281. [4, 7],
  282. [0, 7]],
  283. <BLANKLINE>
  284. [[1, 0],
  285. [3, 0],
  286. [3, 3],
  287. [1, 3]]])
  288. """
  289. if not (x_start.shape == y_start.shape and x_start.dim() in [0, 1]):
  290. raise AssertionError(f"`x_start` and `y_start` must be a scalar or (B,). Got {x_start}, {y_start}.")
  291. if not (width.shape == height.shape and width.dim() in [0, 1]):
  292. raise AssertionError(f"`width` and `height` must be a scalar or (B,). Got {width}, {height}.")
  293. if not x_start.dtype == y_start.dtype == width.dtype == height.dtype:
  294. raise AssertionError(
  295. "All tensors must be in the same dtype. Got "
  296. f"`x_start`({x_start.dtype}), `y_start`({x_start.dtype}), `width`({width.dtype}), `height`({height.dtype})."
  297. )
  298. if not x_start.device == y_start.device == width.device == height.device:
  299. raise AssertionError(
  300. "All tensors must be in the same device. Got "
  301. f"`x_start`({x_start.device}), `y_start`({x_start.device}), "
  302. f"`width`({width.device}), `height`({height.device})."
  303. )
  304. bbox = torch.tensor([[[0, 0], [0, 0], [0, 0], [0, 0]]], device=x_start.device, dtype=x_start.dtype).repeat(
  305. 1 if x_start.dim() == 0 else len(x_start), 1, 1
  306. )
  307. bbox[:, :, 0] += x_start.view(-1, 1)
  308. bbox[:, :, 1] += y_start.view(-1, 1)
  309. bbox[:, 1, 0] += width - 1
  310. bbox[:, 2, 0] += width - 1
  311. bbox[:, 2, 1] += height - 1
  312. bbox[:, 3, 1] += height - 1
  313. return bbox
  314. def bbox_generator3d(
  315. x_start: torch.Tensor,
  316. y_start: torch.Tensor,
  317. z_start: torch.Tensor,
  318. width: torch.Tensor,
  319. height: torch.Tensor,
  320. depth: torch.Tensor,
  321. ) -> torch.Tensor:
  322. """Generate 3D bounding boxes according to the provided start coords, width, height and depth.
  323. Args:
  324. x_start: a tensor containing the x coordinates of the bounding boxes to be extracted. Shape must be a scalar
  325. tensor or :math:`(B,)`.
  326. y_start: a tensor containing the y coordinates of the bounding boxes to be extracted. Shape must be a scalar
  327. tensor or :math:`(B,)`.
  328. z_start: a tensor containing the z coordinates of the bounding boxes to be extracted. Shape must be a scalar
  329. tensor or :math:`(B,)`.
  330. width: widths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
  331. height: heights of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
  332. depth: depths of the masked image. Shape must be a scalar tensor or :math:`(B,)`.
  333. Returns:
  334. the 3d bounding box tensor :math:`(B, 8, 3)`.
  335. Examples:
  336. >>> x_start = torch.tensor([0, 3])
  337. >>> y_start = torch.tensor([1, 4])
  338. >>> z_start = torch.tensor([2, 5])
  339. >>> width = torch.tensor([10, 40])
  340. >>> height = torch.tensor([20, 50])
  341. >>> depth = torch.tensor([30, 60])
  342. >>> bbox_generator3d(x_start, y_start, z_start, width, height, depth)
  343. tensor([[[ 0, 1, 2],
  344. [10, 1, 2],
  345. [10, 21, 2],
  346. [ 0, 21, 2],
  347. [ 0, 1, 32],
  348. [10, 1, 32],
  349. [10, 21, 32],
  350. [ 0, 21, 32]],
  351. <BLANKLINE>
  352. [[ 3, 4, 5],
  353. [43, 4, 5],
  354. [43, 54, 5],
  355. [ 3, 54, 5],
  356. [ 3, 4, 65],
  357. [43, 4, 65],
  358. [43, 54, 65],
  359. [ 3, 54, 65]]])
  360. """
  361. if not (x_start.shape == y_start.shape == z_start.shape and x_start.dim() in [0, 1]):
  362. raise AssertionError(
  363. f"`x_start`, `y_start` and `z_start` must be a scalar or (B,). Got {x_start}, {y_start}, {z_start}."
  364. )
  365. if not (width.shape == height.shape == depth.shape and width.dim() in [0, 1]):
  366. raise AssertionError(f"`width`, `height` and `depth` must be a scalar or (B,). Got {width}, {height}, {depth}.")
  367. if not x_start.dtype == y_start.dtype == z_start.dtype == width.dtype == height.dtype == depth.dtype:
  368. raise AssertionError(
  369. "All tensors must be in the same dtype. "
  370. f"Got `x_start`({x_start.dtype}), `y_start`({x_start.dtype}), `z_start`({x_start.dtype}), "
  371. f"`width`({width.dtype}), `height`({height.dtype}) and `depth`({depth.dtype})."
  372. )
  373. if not x_start.device == y_start.device == z_start.device == width.device == height.device == depth.device:
  374. raise AssertionError(
  375. "All tensors must be in the same device. "
  376. f"Got `x_start`({x_start.device}), `y_start`({x_start.device}), `z_start`({x_start.device}), "
  377. f"`width`({width.device}), `height`({height.device}) and `depth`({depth.device})."
  378. )
  379. # front
  380. bbox = torch.tensor(
  381. [[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]], device=x_start.device, dtype=x_start.dtype
  382. ).repeat(len(x_start), 1, 1)
  383. bbox[:, :, 0] += x_start.view(-1, 1)
  384. bbox[:, :, 1] += y_start.view(-1, 1)
  385. bbox[:, :, 2] += z_start.view(-1, 1)
  386. bbox[:, 1, 0] += width
  387. bbox[:, 2, 0] += width
  388. bbox[:, 2, 1] += height
  389. bbox[:, 3, 1] += height
  390. # back
  391. bbox_back = bbox.clone()
  392. bbox_back[:, :, -1] += depth.unsqueeze(dim=1).repeat(1, 4)
  393. bbox = torch.cat([bbox, bbox_back], dim=1)
  394. return bbox
  395. def transform_bbox(
  396. trans_mat: torch.Tensor, boxes: torch.Tensor, mode: str = "xyxy", restore_coordinates: Optional[bool] = None
  397. ) -> torch.Tensor:
  398. r"""Apply a transformation matrix to a box or batch of boxes.
  399. Args:
  400. trans_mat: The transformation matrix to be applied with a shape of :math:`(3, 3)`
  401. or batched as :math:`(B, 3, 3)`.
  402. boxes: The boxes to be transformed with a common shape of :math:`(N, 4)` or batched as :math:`(B, N, 4)`, the
  403. polygon shape of :math:`(B, N, 4, 2)` is also supported.
  404. mode: The format in which the boxes are provided. If set to 'xyxy' the boxes are assumed to be in the format
  405. ``xmin, ymin, xmax, ymax``. If set to 'xywh' the boxes are assumed to be in the format
  406. ``xmin, ymin, width, height``
  407. restore_coordinates: In case the boxes are flipped, adding a post processing step to restore the
  408. coordinates to a valid bounding box.
  409. Returns:
  410. The set of transformed points in the specified mode
  411. """
  412. if not isinstance(mode, str):
  413. raise TypeError(f"Mode must be a string. Got {type(mode)}")
  414. if mode not in ("xyxy", "xywh"):
  415. raise ValueError(f"Mode must be one of 'xyxy', 'xywh'. Got {mode}")
  416. # (B, N, 4, 2) shaped polygon boxes do not need to be restored.
  417. if restore_coordinates is None and not (boxes.shape[-2:] == torch.Size([4, 2])):
  418. warnings.warn(
  419. "Previous behaviour produces incorrect box coordinates if a flip transformation performed on boxes."
  420. "The previous wrong behaviour has been corrected and will be removed in the future versions."
  421. "If you wish to keep the previous behaviour, please set `restore_coordinates=False`."
  422. "Otherwise, set `restore_coordinates=True` as an acknowledgement.",
  423. stacklevel=1,
  424. )
  425. # convert boxes to format xyxy
  426. if mode == "xywh":
  427. boxes[..., 2] = boxes[..., 0] + boxes[..., 2] # x + w
  428. boxes[..., 3] = boxes[..., 1] + boxes[..., 3] # y + h
  429. transformed_boxes: torch.Tensor = transform_points(trans_mat, boxes.view(boxes.shape[0], -1, 2))
  430. transformed_boxes = transformed_boxes.view_as(boxes)
  431. if (restore_coordinates is None or restore_coordinates) and not (boxes.shape[-2:] == torch.Size([4, 2])):
  432. restored_boxes = transformed_boxes.clone()
  433. # In case the boxes are flipped, we ensure it is ordered like left-top -> right-bot points
  434. restored_boxes[..., 0] = torch.min(transformed_boxes[..., [0, 2]], dim=-1)[0]
  435. restored_boxes[..., 1] = torch.min(transformed_boxes[..., [1, 3]], dim=-1)[0]
  436. restored_boxes[..., 2] = torch.max(transformed_boxes[..., [0, 2]], dim=-1)[0]
  437. restored_boxes[..., 3] = torch.max(transformed_boxes[..., [1, 3]], dim=-1)[0]
  438. transformed_boxes = restored_boxes
  439. if mode == "xywh":
  440. transformed_boxes[..., 2] = transformed_boxes[..., 2] - transformed_boxes[..., 0]
  441. transformed_boxes[..., 3] = transformed_boxes[..., 3] - transformed_boxes[..., 1]
  442. return transformed_boxes
  443. def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
  444. """Perform non-maxima suppression (NMS) on tensor of bounding boxes according to the intersection-over-union (IoU).
  445. Args:
  446. boxes: tensor containing the encoded bounding boxes with the shape :math:`(N, (x_1, y_1, x_2, y_2))`.
  447. scores: tensor containing the scores associated to each bounding box with shape :math:`(N,)`.
  448. iou_threshold: the throshold to discard the overlapping boxes.
  449. Return:
  450. A tensor mask with the indices to keep from the input set of boxes and scores.
  451. Example:
  452. >>> boxes = torch.tensor([
  453. ... [10., 10., 20., 20.],
  454. ... [15., 5., 15., 25.],
  455. ... [100., 100., 200., 200.],
  456. ... [100., 100., 200., 200.]])
  457. >>> scores = torch.tensor([0.9, 0.8, 0.7, 0.9])
  458. >>> nms(boxes, scores, iou_threshold=0.8)
  459. tensor([0, 3, 1])
  460. """
  461. if len(boxes.shape) != 2 and boxes.shape[-1] != 4:
  462. raise ValueError(f"boxes expected as Nx4. Got: {boxes.shape}.")
  463. if len(scores.shape) != 1:
  464. raise ValueError(f"scores expected as N. Got: {scores.shape}.")
  465. if boxes.shape[0] != scores.shape[0]:
  466. raise ValueError(f"boxes and scores mus have same shape. Got: {boxes.shape, scores.shape}.")
  467. x1, y1, x2, y2 = boxes.unbind(-1)
  468. areas = (x2 - x1) * (y2 - y1)
  469. _, order = scores.sort(descending=True)
  470. keep = []
  471. while order.shape[0] > 0:
  472. i = order[0]
  473. keep.append(i)
  474. xx1 = torch.max(x1[i], x1[order[1:]])
  475. yy1 = torch.max(y1[i], y1[order[1:]])
  476. xx2 = torch.min(x2[i], x2[order[1:]])
  477. yy2 = torch.min(y2[i], y2[order[1:]])
  478. w = torch.clamp(xx2 - xx1, min=0.0)
  479. h = torch.clamp(yy2 - yy1, min=0.0)
  480. inter = w * h
  481. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  482. inds = where(ovr <= iou_threshold)[0]
  483. order = order[inds + 1]
  484. if len(keep) > 0:
  485. return stack(keep)
  486. return torch.tensor(keep)