boxes.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional, Tuple, cast
  19. import torch
  20. from torch import Size
  21. from kornia.core import Tensor, stack, zeros
  22. from kornia.geometry.bbox import validate_bbox
  23. from kornia.geometry.linalg import transform_points
  24. from kornia.utils import eye_like
  25. __all__ = ["Boxes", "Boxes3D"]
  26. def _is_floating_point_dtype(dtype: torch.dtype) -> bool:
  27. return dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.half)
  28. def _merge_box_list(boxes: list[torch.Tensor], method: str = "pad") -> tuple[torch.Tensor, list[int]]:
  29. r"""Merge a list of boxes into one tensor."""
  30. if not all(box.shape[-2:] == torch.Size([4, 2]) and box.dim() == 3 for box in boxes):
  31. raise TypeError(f"Input boxes must be a list of (N, 4, 2) shaped. Got: {[box.shape for box in boxes]}.")
  32. if method == "pad":
  33. max_N = max(box.shape[0] for box in boxes)
  34. stats = [max_N - box.shape[0] for box in boxes]
  35. output = torch.nn.utils.rnn.pad_sequence(boxes, batch_first=True)
  36. else:
  37. raise NotImplementedError(f"`{method}` is not implemented.")
  38. return output, stats
  39. def _transform_boxes(boxes: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
  40. """Transform 3D and 2D in kornia format by applying the transformation matrix M.
  41. Boxes and the transformation matrix could be batched or not.
  42. Args:
  43. boxes: 2D quadrilaterals or 3D hexahedrons in kornia format.
  44. M: the transformation matrix of shape :math:`(3, 3)` or :math:`(B, 3, 3)` for 2D and :math:`(4, 4)` or
  45. :math:`(B, 4, 4)` for 3D hexahedron.
  46. """
  47. M = M if M.is_floating_point() else M.float()
  48. # Work with batch as kornia.transform_points only supports a batch of points.
  49. boxes_per_batch, n_points_per_box, coordinates_dimension = boxes.shape[-3:]
  50. if boxes_per_batch == 0:
  51. return boxes
  52. points = boxes.view(-1, n_points_per_box * boxes_per_batch, coordinates_dimension)
  53. M = M if M.ndim == 3 else M.unsqueeze(0)
  54. if points.shape[0] != M.shape[0]:
  55. raise ValueError(
  56. f"Batch size mismatch. Got {points.shape[0]} for boxes and {M.shape[0]} for the transformation matrix."
  57. )
  58. transformed_boxes: torch.Tensor = transform_points(M, points)
  59. transformed_boxes = transformed_boxes.view_as(boxes)
  60. return transformed_boxes
  61. def _boxes_to_polygons(
  62. xmin: torch.Tensor, ymin: torch.Tensor, width: torch.Tensor, height: torch.Tensor
  63. ) -> torch.Tensor:
  64. if not xmin.ndim == ymin.ndim == width.ndim == height.ndim == 2:
  65. raise ValueError("We expect to create a batch of 2D boxes (quadrilaterals) in vertices format (B, N, 4, 2)")
  66. # Create (B,N,4,2) with all points in top left position of boxes
  67. polygons = zeros((xmin.shape[0], xmin.shape[1], 4, 2), device=xmin.device, dtype=xmin.dtype)
  68. polygons[..., 0] = xmin.unsqueeze(-1)
  69. polygons[..., 1] = ymin.unsqueeze(-1)
  70. # Shift top-right, bottom-right, bottom-left points to the right coordinates
  71. polygons[..., 1, 0] += width - 1 # Top right
  72. polygons[..., 2, 0] += width - 1 # Bottom right
  73. polygons[..., 2, 1] += height - 1 # Bottom right
  74. polygons[..., 3, 1] += height - 1 # Bottom left
  75. return polygons
  76. def _boxes_to_quadrilaterals(boxes: torch.Tensor, mode: str = "xyxy", validate_boxes: bool = True) -> torch.Tensor:
  77. """Convert from boxes to quadrilaterals."""
  78. mode = mode.lower()
  79. if mode.startswith("vertices"):
  80. batched = boxes.ndim == 4
  81. if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == torch.Size([4, 2])):
  82. raise ValueError(f"Boxes shape must be (N, 4, 2) or (B, N, 4, 2) when {mode} mode. Got {boxes.shape}.")
  83. elif mode.startswith("xy"):
  84. batched = boxes.ndim == 3
  85. if not (2 <= boxes.ndim <= 3 and boxes.shape[-1] == 4):
  86. raise ValueError(f"Boxes shape must be (N, 4) or (B, N, 4) when {mode} mode. Got {boxes.shape}.")
  87. else:
  88. raise ValueError(f"Unknown mode {mode}")
  89. boxes = boxes if boxes.is_floating_point() else boxes.float()
  90. boxes = boxes if batched else boxes.unsqueeze(0)
  91. if mode.startswith("vertices"):
  92. if mode == "vertices":
  93. quadrilaterals = boxes.clone()
  94. # Here, vertices are quadrilaterals with width and height defined as `width = xmax - xmin` and
  95. # `height = ymax - ymin`. We need to convert to `width = xmax - xmin + 1` and `height = ymax - ymin + 1` to
  96. # match with internal Boxes Kornia representation.
  97. quadrilaterals[..., 1:3, 0] = quadrilaterals[..., 1:3, 0] - 1
  98. quadrilaterals[..., 2:, 1] = quadrilaterals[..., 2:, 1] - 1
  99. elif mode == "vertices_plus":
  100. # Avoid passing reference
  101. quadrilaterals = boxes.clone()
  102. else:
  103. raise ValueError(f"Unknown mode {mode}")
  104. not validate_boxes or validate_bbox(quadrilaterals)
  105. elif mode.startswith("xy"):
  106. if mode == "xyxy":
  107. height, width = boxes[..., 3] - boxes[..., 1], boxes[..., 2] - boxes[..., 0]
  108. elif mode == "xyxy_plus":
  109. height, width = boxes[..., 3] - boxes[..., 1] + 1, boxes[..., 2] - boxes[..., 0] + 1
  110. elif mode == "xywh":
  111. height, width = boxes[..., 3], boxes[..., 2]
  112. else:
  113. raise ValueError(f"Unknown mode {mode}")
  114. if validate_boxes:
  115. if (width <= 0).any():
  116. raise ValueError("Some boxes have negative widths or 0.")
  117. if (height <= 0).any():
  118. raise ValueError("Some boxes have negative heights or 0.")
  119. xmin, ymin = boxes[..., 0], boxes[..., 1]
  120. quadrilaterals = _boxes_to_polygons(xmin, ymin, width, height)
  121. else:
  122. raise ValueError(f"Unknown mode {mode}")
  123. quadrilaterals = quadrilaterals if batched else quadrilaterals.squeeze(0)
  124. return quadrilaterals
  125. def _boxes3d_to_polygons3d(
  126. xmin: torch.Tensor,
  127. ymin: torch.Tensor,
  128. zmin: torch.Tensor,
  129. width: torch.Tensor,
  130. height: torch.Tensor,
  131. depth: torch.Tensor,
  132. ) -> torch.Tensor:
  133. if not xmin.ndim == ymin.ndim == zmin.ndim == width.ndim == height.ndim == depth.ndim == 2:
  134. raise ValueError("We expect to create a batch of 3D boxes (hexahedrons) in vertices format (B, N, 8, 3)")
  135. # Front
  136. # Create (B,N,4,3) with all points in front top left position of boxes
  137. front_vertices = zeros((xmin.shape[0], xmin.shape[1], 4, 3), device=xmin.device, dtype=xmin.dtype)
  138. front_vertices[..., 0] = xmin.unsqueeze(-1)
  139. front_vertices[..., 1] = ymin.unsqueeze(-1)
  140. front_vertices[..., 2] = zmin.unsqueeze(-1)
  141. # Shift front-top-right, front-bottom-right, front-bottom-left points to the right coordinates
  142. front_vertices[..., 1, 0] += width - 1 # Top right
  143. front_vertices[..., 2, 0] += width - 1 # Bottom right
  144. front_vertices[..., 2, 1] += height - 1 # Bottom right
  145. front_vertices[..., 3, 1] += height - 1 # Bottom left
  146. # Back
  147. back_vertices = front_vertices.clone()
  148. back_vertices[..., 2] += depth.unsqueeze(-1) - 1
  149. polygons3d = torch.cat([front_vertices, back_vertices], dim=-2)
  150. return polygons3d
  151. class Boxes:
  152. r"""2D boxes containing N or BxN boxes.
  153. Args:
  154. boxes: 2D boxes, shape of :math:`(N, 4, 2)`, :math:`(B, N, 4, 2)` or a list of :math:`(N, 4, 2)`.
  155. See below for more details.
  156. raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a
  157. floating point tensor. True to raise an error when `boxes` isn't a floating point tensor, False
  158. to cast to float.
  159. mode: the box format of the input boxes.
  160. Note:
  161. **2D boxes format** is defined as a floating data type tensor of shape ``Nx4x2`` or ``BxNx4x2``
  162. where each box is a `quadrilateral <https://en.wikipedia.org/wiki/Quadrilateral>`_ defined by it's
  163. 4 vertices coordinates (A, B, C, D). Coordinates must be in ``x, y`` order. The height and width of
  164. a box is defined as ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. Examples of
  165. `quadrilaterals <https://en.wikipedia.org/wiki/Quadrilateral>`_ are rectangles, rhombus and trapezoids.
  166. """
  167. def __init__(
  168. self,
  169. boxes: torch.Tensor | list[torch.Tensor],
  170. raise_if_not_floating_point: bool = True,
  171. mode: str = "vertices_plus",
  172. ) -> None:
  173. self._N: Optional[list[int]] = None
  174. if isinstance(boxes, list):
  175. boxes, self._N = _merge_box_list(boxes)
  176. if not isinstance(boxes, torch.Tensor):
  177. raise TypeError(f"Input boxes is not a Tensor. Got: {type(boxes)}.")
  178. if not boxes.is_floating_point():
  179. if raise_if_not_floating_point:
  180. raise ValueError(f"Coordinates must be in floating point. Got {boxes.dtype}")
  181. boxes = boxes.float()
  182. if len(boxes.shape) == 0:
  183. boxes = boxes.reshape((-1, 4))
  184. if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == (4, 2)):
  185. raise ValueError(f"Boxes shape must be (N, 4, 2) or (B, N, 4, 2). Got {boxes.shape}.")
  186. self._is_batched = False if boxes.ndim == 3 else True
  187. self._data = boxes
  188. self._mode = mode
  189. def __getitem__(self, key: slice | int | Tensor) -> Boxes:
  190. new_box = type(self)(self._data[key], False)
  191. new_box._mode = self._mode
  192. return new_box
  193. def __setitem__(self, key: slice | int | Tensor, value: Boxes) -> Boxes:
  194. self._data[key] = value._data
  195. return self
  196. @property
  197. def shape(self) -> tuple[int, ...] | Size:
  198. return self.data.shape
  199. def get_boxes_shape(self) -> tuple[torch.Tensor, torch.Tensor]:
  200. r"""Compute boxes heights and widths.
  201. Returns:
  202. - Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
  203. - Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.
  204. Example:
  205. >>> boxes_xyxy = torch.tensor([[[1,1,2,2],[1,1,3,2]]])
  206. >>> boxes = Boxes.from_tensor(boxes_xyxy)
  207. >>> boxes.get_boxes_shape()
  208. (tensor([[1., 1.]]), tensor([[1., 2.]]))
  209. """
  210. boxes_xywh = cast(torch.Tensor, self.to_tensor("xywh", as_padded_sequence=True))
  211. widths, heights = boxes_xywh[..., 2], boxes_xywh[..., 3]
  212. return heights, widths
  213. def merge(self, boxes: Boxes, inplace: bool = False) -> Boxes:
  214. """Merge boxes.
  215. Say, current instance holds :math:`(B, N, 4, 2)` and the incoming boxes holds :math:`(B, M, 4, 2)`,
  216. the merge results in :math:`(B, N + M, 4, 2)`.
  217. Args:
  218. boxes: 2D boxes.
  219. inplace: do transform in-place and return self.
  220. """
  221. data = torch.cat([self._data, boxes.data], dim=1)
  222. if inplace:
  223. self._data = data
  224. return self
  225. obj = self.clone()
  226. obj._data = data
  227. return obj
  228. def index_put(
  229. self, indices: tuple[Tensor, ...] | list[Tensor], values: Tensor | Boxes, inplace: bool = False
  230. ) -> Boxes:
  231. if inplace:
  232. _data = self._data
  233. else:
  234. _data = self._data.clone()
  235. if isinstance(values, Boxes):
  236. _data.index_put_(indices, values.data)
  237. else:
  238. _data.index_put_(indices, values)
  239. if inplace:
  240. return self
  241. obj = self.clone()
  242. obj._data = _data
  243. return obj
  244. def pad(self, padding_size: Tensor) -> Boxes:
  245. """Pad a bounding box.
  246. Args:
  247. padding_size: (B, 4)
  248. """
  249. if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
  250. raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
  251. self._data[..., 0] += padding_size[..., None, :1].to(device=self._data.device) # left padding
  252. self._data[..., 1] += padding_size[..., None, 2:3].to(device=self._data.device) # top padding
  253. return self
  254. def unpad(self, padding_size: Tensor) -> Boxes:
  255. """Pad a bounding box.
  256. Args:
  257. padding_size: (B, 4)
  258. """
  259. if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
  260. raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
  261. self._data[..., 0] -= padding_size[..., None, :1].to(device=self._data.device) # left padding
  262. self._data[..., 1] -= padding_size[..., None, 2:3].to(device=self._data.device) # top padding
  263. return self
  264. def clamp(
  265. self,
  266. topleft: Optional[Tensor | tuple[int, int]] = None,
  267. botright: Optional[Tensor | tuple[int, int]] = None,
  268. inplace: bool = False,
  269. ) -> Boxes:
  270. if not (isinstance(topleft, Tensor) and isinstance(botright, Tensor)):
  271. raise NotImplementedError
  272. if inplace:
  273. _data = self._data
  274. else:
  275. _data = self._data.clone()
  276. topleft_x = topleft[:, None, :1].repeat(1, _data.size(1), 4)
  277. _data[..., 0][_data[..., 0] < topleft_x] = topleft_x[_data[..., 0] < topleft_x]
  278. topleft_y = topleft[:, None, 1:].repeat(1, _data.size(1), 4)
  279. _data[..., 1][_data[..., 1] < topleft_y] = topleft_y[_data[..., 1] < topleft_y]
  280. botright_x = botright[:, None, :1].repeat(1, _data.size(1), 4)
  281. _data[..., 0][_data[..., 0] > botright_x] = botright_x[_data[..., 0] > botright_x]
  282. botright_y = botright[:, None, 1:].repeat(1, _data.size(1), 4)
  283. _data[..., 1][_data[..., 1] > botright_y] = botright_y[_data[..., 1] > botright_y]
  284. if inplace:
  285. return self
  286. obj = self.clone()
  287. obj._data = _data
  288. return obj
  289. def trim(self, correspondence_preserve: bool = False, inplace: bool = False) -> Boxes:
  290. """Trim out zero padded boxes.
  291. Given box arrangements of shape :math:`(4, 4, Box)`:
  292. == === == === == === == === ==
  293. -- Box -- Box -- Box -- Box --
  294. -- 0 -- 0 -- Box -- Box --
  295. -- 0 -- Box -- 0 -- 0 --
  296. -- 0 -- 0 -- 0 -- 0 --
  297. == === == === == === == === ==
  298. Nothing will change if correspondence_preserve is True. Only pure zero layers will be removed, resulting in
  299. shape :math:`(4, 3, Box)`:
  300. == === == === == === == === ==
  301. -- Box -- Box -- Box -- Box --
  302. -- 0 -- 0 -- Box -- Box --
  303. -- 0 -- Box -- 0 -- 0 --
  304. == === == === == === == === ==
  305. Otherwise, you will get :math:`(4, 2, Box)`:
  306. == === == === == === == === ==
  307. -- Box -- Box -- Box -- Box --
  308. -- 0 -- Box -- Box -- Box --
  309. == === == === == === == === ==
  310. """
  311. raise NotImplementedError
  312. def filter_boxes_by_area(
  313. self, min_area: Optional[float] = None, max_area: Optional[float] = None, inplace: bool = False
  314. ) -> Boxes:
  315. area = self.compute_area()
  316. if inplace:
  317. _data = self._data
  318. else:
  319. _data = self._data.clone()
  320. if min_area is not None:
  321. _data[area < min_area] = 0.0
  322. if max_area is not None:
  323. _data[area > max_area] = 0.0
  324. if inplace:
  325. return self
  326. obj = self.clone()
  327. obj._data = _data
  328. return obj
  329. def compute_area(self) -> torch.Tensor:
  330. """Return :math:`(B, N)`."""
  331. coords = self._data.view((-1, 4, 2)) if self._data.ndim == 4 else self._data
  332. # calculate centroid of the box
  333. centroid = coords.mean(dim=1, keepdim=True)
  334. # calculate the angle from centroid to each corner
  335. angles = torch.atan2(coords[..., 1] - centroid[..., 1], coords[..., 0] - centroid[..., 0])
  336. # sort the corners by angle to get an order for shoelace formula
  337. _, clockwise_indices = torch.sort(angles, dim=1, descending=True)
  338. # gather the corners in the new order
  339. ordered_corners = torch.gather(coords, 1, clockwise_indices.unsqueeze(-1).expand(-1, -1, 2))
  340. x, y = ordered_corners[..., 0], ordered_corners[..., 1]
  341. # Gaussian/Shoelace formula https://en.wikipedia.org/wiki/Shoelace_formula
  342. area = 0.5 * torch.abs(torch.sum((x * torch.roll(y, 1, 1)) - (y * torch.roll(x, 1, 1)), dim=1))
  343. return area.view(self._data.shape[:2]) if self._data.ndim == 4 else area
  344. @classmethod
  345. def from_tensor(
  346. cls, boxes: torch.Tensor | list[torch.Tensor], mode: str = "xyxy", validate_boxes: bool = True
  347. ) -> Boxes:
  348. r"""Create :class:`Boxes` from boxes stored in another format.
  349. Args:
  350. boxes: 2D boxes, shape of :math:`(N, 4)`, :math:`(B, N, 4)`, :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
  351. mode: The format in which the boxes are provided.
  352. * 'xyxy': boxes are assumed to be in the format ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin``
  353. and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
  354. * 'xyxy_plus': similar to 'xyxy' mode but where box width and length are defined as
  355. ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
  356. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
  357. * 'xywh': boxes are assumed to be in the format ``xmin, ymin, width, height`` where
  358. ``width = xmax - xmin`` and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
  359. * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
  360. *top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
  361. box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
  362. With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
  363. * 'vertices_plus': similar to 'vertices' mode but where box width and length are defined as
  364. ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. ymin + 1``.
  365. With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
  366. validate_boxes: check if boxes are valid rectangles or not. Valid rectangles are those with width
  367. and height >= 1 (>= 2 when mode ends with '_plus' suffix).
  368. Returns:
  369. :class:`Boxes` class containing the original `boxes` in the format specified by ``mode``.
  370. Examples:
  371. >>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
  372. >>> boxes = Boxes.from_tensor(boxes_xyxy, mode='xyxy')
  373. >>> boxes.data # (2, 4, 2)
  374. tensor([[[0., 3.],
  375. [0., 3.],
  376. [0., 3.],
  377. [0., 3.]],
  378. <BLANKLINE>
  379. [[5., 1.],
  380. [7., 1.],
  381. [7., 3.],
  382. [5., 3.]]])
  383. """
  384. quadrilaterals: torch.Tensor | list[torch.Tensor]
  385. if isinstance(boxes, torch.Tensor):
  386. quadrilaterals = _boxes_to_quadrilaterals(boxes, mode=mode, validate_boxes=validate_boxes)
  387. else:
  388. quadrilaterals = [_boxes_to_quadrilaterals(box, mode, validate_boxes) for box in boxes]
  389. return cls(quadrilaterals, False, mode)
  390. def to_tensor(
  391. self, mode: Optional[str] = None, as_padded_sequence: bool = False
  392. ) -> torch.Tensor | list[torch.Tensor]:
  393. r"""Cast :class:`Boxes` to a tensor.
  394. ``mode`` controls which 2D boxes format should be use to represent boxes in the tensor.
  395. Args:
  396. mode: the output box format. It could be:
  397. * 'xyxy': boxes are defined as ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin`` and
  398. ``height = ymax - ymin``.
  399. * 'xyxy_plus': similar to 'xyxy' mode but where box width and length are defined as
  400. ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
  401. * 'xywh': boxes are defined as ``xmin, ymin, width, height`` where ``width = xmax - xmin``
  402. and ``height = ymax - ymin``.
  403. * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
  404. *top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
  405. box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
  406. * 'vertices_plus': similar to 'vertices' mode but where box width and length are defined as
  407. ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. ymin + 1``.
  408. as_padded_sequence: whether to keep the pads for a list of boxes. This parameter is only valid
  409. if the boxes are from a box list whilst `from_tensor`.
  410. Returns:
  411. Boxes tensor in the ``mode`` format. The shape depends with the ``mode`` value:
  412. * 'vertices' or 'verticies_plus': :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
  413. * Any other value: :math:`(N, 4)` or :math:`(B, N, 4)`.
  414. Examples:
  415. >>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
  416. >>> boxes = Boxes.from_tensor(boxes_xyxy)
  417. >>> assert (boxes_xyxy == boxes.to_tensor(mode='xyxy')).all()
  418. """
  419. batched_boxes = self._data if self._is_batched else self._data.unsqueeze(0)
  420. boxes: torch.Tensor | list[torch.Tensor]
  421. # Create boxes in xyxy_plus format.
  422. boxes = torch.stack([batched_boxes.amin(dim=-2), batched_boxes.amax(dim=-2)], dim=-2).view(
  423. batched_boxes.shape[0], batched_boxes.shape[1], 4
  424. )
  425. if mode is None:
  426. mode = self.mode
  427. mode = mode.lower()
  428. if mode in ("xyxy", "xyxy_plus"):
  429. pass
  430. elif mode in ("xywh", "vertices", "vertices_plus"):
  431. height, width = boxes[..., 3] - boxes[..., 1] + 1, boxes[..., 2] - boxes[..., 0] + 1
  432. boxes[..., 2] = width
  433. boxes[..., 3] = height
  434. else:
  435. raise ValueError(f"Unknown mode {mode}")
  436. if mode in ("xyxy", "vertices"):
  437. offset = torch.as_tensor([0, 0, 1, 1], device=boxes.device, dtype=boxes.dtype)
  438. boxes = boxes + offset
  439. if mode.startswith("vertices"):
  440. boxes = _boxes_to_polygons(boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3])
  441. if self._N is not None and not as_padded_sequence:
  442. boxes = [torch.nn.functional.pad(o, (len(o.shape) - 1) * [0, 0] + [0, -n]) for o, n in zip(boxes, self._N)]
  443. else:
  444. boxes = boxes if self._is_batched else boxes.squeeze(0)
  445. return boxes
  446. def to_mask(self, height: int, width: int) -> torch.Tensor:
  447. """Convert 2D boxes to masks. Covered area is 1 and the remaining is 0.
  448. Args:
  449. height: height of the masked image/images.
  450. width: width of the masked image/images.
  451. Returns:
  452. the output mask tensor, shape of :math:`(N, width, height)` or :math:`(B,N, width, height)` and dtype of
  453. :func:`Boxes.dtype` (it can be any floating point dtype).
  454. Note:
  455. It is currently non-differentiable.
  456. Examples:
  457. >>> boxes = Boxes(torch.tensor([[ # Equivalent to boxes = Boxes.from_tensor([[1,1,4,3]])
  458. ... [1., 1.],
  459. ... [4., 1.],
  460. ... [4., 3.],
  461. ... [1., 3.],
  462. ... ]])) # 1x4x2
  463. >>> boxes.to_mask(5, 5)
  464. tensor([[[0., 0., 0., 0., 0.],
  465. [0., 1., 1., 1., 1.],
  466. [0., 1., 1., 1., 1.],
  467. [0., 1., 1., 1., 1.],
  468. [0., 0., 0., 0., 0.]]])
  469. """
  470. if self._data.requires_grad:
  471. raise RuntimeError(
  472. "Boxes.to_tensor isn't differentiable. Please, create boxes from tensors with `requires_grad=False`."
  473. )
  474. is_batched = self._is_batched
  475. dtype = self.dtype
  476. device = self.device
  477. # -----------------
  478. # CPU Hotpath (loop)
  479. # -----------------
  480. if device.type != "cuda":
  481. if self._is_batched: # (B, N, 4, 2)
  482. mask = torch.zeros(
  483. (self._data.shape[0], self._data.shape[1], height, width), dtype=self.dtype, device=self.device
  484. )
  485. else: # (N, 4, 2)
  486. mask = torch.zeros((self._data.shape[0], height, width), dtype=self.dtype, device=self.device)
  487. # Boxes coordinates can be outside the image size after transforms. Clamp values to the image size
  488. clipped_boxes_xyxy = cast(torch.Tensor, self.to_tensor("xyxy", as_padded_sequence=True))
  489. clipped_boxes_xyxy[..., ::2].clamp_(0, width)
  490. clipped_boxes_xyxy[..., 1::2].clamp_(0, height)
  491. # Reshape mask to (BxN, H, W) and boxes to (BxN, 4) to iterate over all of them.
  492. # Cast boxes coordinates to be integer to use them as indexes. Use round to handle decimal values.
  493. for mask_channel, box_xyxy in zip(
  494. mask.view(-1, height, width), clipped_boxes_xyxy.view(-1, 4).round().int()
  495. ):
  496. # Mask channel dimensions: (height, width)
  497. mask_channel[box_xyxy[1] : box_xyxy[3], box_xyxy[0] : box_xyxy[2]] = 1
  498. return mask
  499. # -----------------
  500. # GPU Hotpath (vectorized)
  501. # -----------------
  502. out_shape: Tuple[int, ...]
  503. if is_batched:
  504. out_shape = (self.shape[0], self.shape[1], height, width)
  505. else:
  506. out_shape = (self.shape[0], height, width)
  507. clipped_boxes_xyxy = cast(Tensor, self.to_tensor("xyxy", as_padded_sequence=True))
  508. clipped_boxes_xyxy[..., ::2].clamp_(0, width)
  509. clipped_boxes_xyxy[..., 1::2].clamp_(0, height)
  510. xyxy = clipped_boxes_xyxy.view(-1, 4).round().long()
  511. x1, y1, x2, y2 = xyxy[:, 0], xyxy[:, 1], xyxy[:, 2], xyxy[:, 3]
  512. x1 = x1.clamp(0, width)
  513. x2 = x2.clamp(0, width)
  514. y1 = y1.clamp(0, height)
  515. y2 = y2.clamp(0, height)
  516. ys = torch.arange(height, device=device)
  517. xs = torch.arange(width, device=device)
  518. y_mask = (ys[None, :] >= y1[:, None]) & (ys[None, :] < y2[:, None] + 1)
  519. x_mask = (xs[None, :] >= x1[:, None]) & (xs[None, :] < x2[:, None] + 1)
  520. masks = (y_mask.unsqueeze(2) & x_mask.unsqueeze(1)).to(dtype)
  521. return masks.view(*out_shape)
  522. def transform_boxes(self, M: torch.Tensor, inplace: bool = False) -> Boxes:
  523. r"""Apply a transformation matrix to the 2D boxes.
  524. Args:
  525. M: The transformation matrix to be applied, shape of :math:`(3, 3)` or :math:`(B, 3, 3)`.
  526. inplace: do transform in-place and return self.
  527. Returns:
  528. The transformed boxes.
  529. """
  530. if not 2 <= M.ndim <= 3 or M.shape[-2:] != (3, 3):
  531. raise ValueError(f"The transformation matrix shape must be (3, 3) or (B, 3, 3). Got {M.shape}.")
  532. transformed_boxes = _transform_boxes(self._data, M)
  533. if inplace:
  534. self._data = transformed_boxes
  535. return self
  536. obj = self.clone()
  537. obj._data = transformed_boxes
  538. return obj
  539. def transform_boxes_(self, M: torch.Tensor) -> Boxes:
  540. """Inplace version of :func:`Boxes.transform_boxes`."""
  541. return self.transform_boxes(M, inplace=True)
  542. def translate(self, size: Tensor, method: str = "warp", inplace: bool = False) -> Boxes:
  543. """Translate boxes by the provided size.
  544. Args:
  545. size: translate size for x, y direction, shape of :math:`(B, 2)`.
  546. method: "warp" or "fast".
  547. inplace: do transform in-place and return self.
  548. Returns:
  549. The transformed boxes.
  550. """
  551. if method == "fast":
  552. raise NotImplementedError
  553. elif method == "warp":
  554. pass
  555. else:
  556. raise NotImplementedError
  557. M: Tensor = eye_like(3, size)
  558. M[:, :2, 2] = size
  559. return self.transform_boxes(M, inplace=inplace)
  560. @property
  561. def data(self) -> torch.Tensor:
  562. return self._data
  563. @property
  564. def mode(self) -> str:
  565. return self._mode
  566. @property
  567. def device(self) -> torch.device:
  568. """Returns boxes device."""
  569. return self._data.device
  570. @property
  571. def dtype(self) -> torch.dtype:
  572. """Returns boxes dtype."""
  573. return self._data.dtype
  574. def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Boxes:
  575. """Like :func:`torch.nn.Module.to()` method."""
  576. # In torchscript, dtype is a int and not a class. https://github.com/pytorch/pytorch/issues/51941
  577. if dtype is not None and not _is_floating_point_dtype(dtype):
  578. raise ValueError("Boxes must be in floating point")
  579. self._data = self._data.to(device=device, dtype=dtype)
  580. return self
  581. def clone(self) -> Boxes:
  582. obj = type(self)(self._data.clone(), False)
  583. obj._mode = self._mode
  584. obj._N = self._N
  585. obj._is_batched = self._is_batched
  586. return obj
  587. def type(self, dtype: torch.dtype) -> Boxes:
  588. self._data = self._data.type(dtype)
  589. return self
  590. class VideoBoxes(Boxes):
  591. temporal_channel_size: int
  592. @classmethod
  593. def from_tensor( # type: ignore[override]
  594. cls, boxes: torch.Tensor | list[torch.Tensor], validate_boxes: bool = True
  595. ) -> VideoBoxes:
  596. if isinstance(boxes, (list,)) or (boxes.dim() != 5 or boxes.shape[-2:] != torch.Size([4, 2])):
  597. raise ValueError("Input box type is not yet supported. Please input an `BxTxNx4x2` tensor directly.")
  598. temporal_channel_size = boxes.size(1)
  599. quadrilaterals = _boxes_to_quadrilaterals(
  600. boxes.view(boxes.size(0) * boxes.size(1), -1, boxes.size(3), boxes.size(4)),
  601. mode="vertices_plus",
  602. validate_boxes=validate_boxes,
  603. )
  604. out = cls(quadrilaterals, False, "vertices_plus")
  605. out.temporal_channel_size = temporal_channel_size
  606. return out
  607. def to_tensor(self, mode: Optional[str] = None) -> torch.Tensor | list[torch.Tensor]: # type: ignore[override]
  608. out = super().to_tensor(mode, as_padded_sequence=False)
  609. if isinstance(out, Tensor):
  610. return out.view(-1, self.temporal_channel_size, *out.shape[1:])
  611. # If returns a list of boxes.
  612. return [_out.view(-1, self.temporal_channel_size, *_out.shape[1:]) for _out in out]
  613. def clone(self) -> VideoBoxes:
  614. obj = type(self)(self._data.clone(), False)
  615. obj._mode = self._mode
  616. obj._N = self._N
  617. obj._is_batched = self._is_batched
  618. obj.temporal_channel_size = self.temporal_channel_size
  619. return obj
  620. class Boxes3D:
  621. r"""3D boxes containing N or BxN boxes.
  622. Args:
  623. boxes: 3D boxes, shape of :math:`(N,8,3)` or :math:`(B,N,8,3)`. See below for more details.
  624. raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a floating
  625. point tensor. True to raise an error when `boxes` isn't a floating point tensor, False to cast to float.
  626. Note:
        **3D boxes format** is defined as a floating data type tensor of shape ``Nx8x3`` or ``BxNx8x3`` where each box
        is a `hexahedron <https://en.wikipedia.org/wiki/Hexahedron>`_ defined by its 8 vertex coordinates.
        Coordinates must be in ``x, y, z`` order. The height, width and depth of a box are defined as
        ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``. Examples of
        `hexahedrons <https://en.wikipedia.org/wiki/Hexahedron>`_ are cubes and rhombohedrons.
  632. """
  633. def __init__(
  634. self, boxes: torch.Tensor, raise_if_not_floating_point: bool = True, mode: str = "xyzxyz_plus"
  635. ) -> None:
  636. if not isinstance(boxes, torch.Tensor):
  637. raise TypeError(f"Input boxes is not a Tensor. Got: {type(boxes)}.")
  638. if not boxes.is_floating_point():
  639. if raise_if_not_floating_point:
  640. raise ValueError(f"Coordinates must be in floating point. Got {boxes.dtype}.")
  641. boxes = boxes.float()
  642. if len(boxes.shape) == 0:
  643. boxes = boxes.reshape((-1, 6))
  644. if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == (8, 3)):
  645. raise ValueError(f"3D bbox shape must be (N, 8, 3) or (B, N, 8, 3). Got {boxes.shape}.")
  646. self._is_batched = False if boxes.ndim == 3 else True
  647. self._data = boxes
  648. self._mode = mode
  649. def __getitem__(self, key: slice | int | Tensor) -> Boxes3D:
  650. new_box = Boxes3D(self._data[key], False, mode="xyzxyz_plus")
  651. new_box._mode = self._mode
  652. return new_box
  653. def __setitem__(self, key: slice | int | Tensor, value: Boxes3D) -> Boxes3D:
  654. self._data[key] = value._data
  655. return self
  656. @property
  657. def shape(self) -> tuple[int, ...] | Size:
  658. return self.data.shape
  659. def get_boxes_shape(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  660. r"""Compute boxes heights and widths.
  661. Returns:
  662. - Boxes depths, shape of :math:`(N,)` or :math:`(B,N)`.
  663. - Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
  664. - Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.
  665. Example:
  666. >>> boxes_xyzxyz = torch.tensor([[ 0, 1, 2, 10, 21, 32], [3, 4, 5, 43, 54, 65]])
  667. >>> boxes3d = Boxes3D.from_tensor(boxes_xyzxyz)
  668. >>> boxes3d.get_boxes_shape()
  669. (tensor([30., 60.]), tensor([20., 50.]), tensor([10., 40.]))
  670. """
  671. boxes_xyzwhd = self.to_tensor(mode="xyzwhd")
  672. widths, heights, depths = boxes_xyzwhd[..., 3], boxes_xyzwhd[..., 4], boxes_xyzwhd[..., 5]
  673. return depths, heights, widths
  674. @classmethod
  675. def from_tensor(cls, boxes: torch.Tensor, mode: str = "xyzxyz", validate_boxes: bool = True) -> Boxes3D:
  676. r"""Create :class:`Boxes3D` from 3D boxes stored in another format.
  677. Args:
  678. boxes: 3D boxes, shape of :math:`(N,6)` or :math:`(B,N,6)`.
  679. mode: The format in which the 3D boxes are provided.
  680. * 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
  681. ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
  682. * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, length and depth are defined as
  683. ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
  684. * 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
  685. ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
  686. validate_boxes: check if boxes are valid rectangles or not. Valid rectangles are those with width, height
  687. and depth >= 1 (>= 2 when mode ends with '_plus' suffix).
  688. Returns:
  689. :class:`Boxes3D` class containing the original `boxes` in the format specified by ``mode``.
  690. Examples:
  691. >>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
  692. >>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
  693. >>> boxes.data # (2, 8, 3)
  694. tensor([[[0., 3., 6.],
  695. [0., 3., 6.],
  696. [0., 3., 6.],
  697. [0., 3., 6.],
  698. [0., 3., 7.],
  699. [0., 3., 7.],
  700. [0., 3., 7.],
  701. [0., 3., 7.]],
  702. <BLANKLINE>
  703. [[5., 1., 3.],
  704. [7., 1., 3.],
  705. [7., 3., 3.],
  706. [5., 3., 3.],
  707. [5., 1., 8.],
  708. [7., 1., 8.],
  709. [7., 3., 8.],
  710. [5., 3., 8.]]])
  711. """
  712. if not (2 <= boxes.ndim <= 3 and boxes.shape[-1] == 6):
  713. raise ValueError(f"BBox shape must be (N, 6) or (B, N, 6). Got {boxes.shape}.")
  714. batched = boxes.ndim == 3
  715. boxes = boxes if batched else boxes.unsqueeze(0)
  716. boxes = boxes if boxes.is_floating_point() else boxes.float()
  717. xmin, ymin, zmin = boxes[..., 0], boxes[..., 1], boxes[..., 2]
  718. mode = mode.lower()
  719. if mode == "xyzxyz":
  720. width = boxes[..., 3] - boxes[..., 0]
  721. height = boxes[..., 4] - boxes[..., 1]
  722. depth = boxes[..., 5] - boxes[..., 2]
  723. elif mode == "xyzxyz_plus":
  724. width = boxes[..., 3] - boxes[..., 0] + 1
  725. height = boxes[..., 4] - boxes[..., 1] + 1
  726. depth = boxes[..., 5] - boxes[..., 2] + 1
  727. elif mode == "xyzwhd":
  728. depth, height, width = boxes[..., 4], boxes[..., 3], boxes[..., 5]
  729. else:
  730. raise ValueError(f"Unknown mode {mode}")
  731. if validate_boxes:
  732. if (width <= 0).any():
  733. raise ValueError("Some boxes have negative widths or 0.")
  734. if (height <= 0).any():
  735. raise ValueError("Some boxes have negative heights or 0.")
  736. if (depth <= 0).any():
  737. raise ValueError("Some boxes have negative depths or 0.")
  738. hexahedrons = _boxes3d_to_polygons3d(xmin, ymin, zmin, width, height, depth)
  739. hexahedrons = hexahedrons if batched else hexahedrons.squeeze(0)
  740. return cls(hexahedrons, raise_if_not_floating_point=False, mode=mode)
  741. def to_tensor(self, mode: str = "xyzxyz") -> torch.Tensor:
  742. r"""Cast :class:`Boxes3D` to a tensor.
  743. ``mode`` controls which 3D boxes format should be use to represent boxes in the tensor.
  744. Args:
  745. mode: The format in which the boxes are provided.
  746. * 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
  747. ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
  748. * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, length and depth are defined as
  749. ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
  750. * 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
  751. ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
  752. * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
  753. *front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left,
  754. back-top-right, back-bottom-right, back-bottom-left*. Vertices coordinates are in (x,y, z) order.
  755. Finally, box width, height and depth are defined as ``width = xmax - xmin``, ``height = ymax - ymin``
  756. and ``depth = zmax - zmin``.
  757. * 'vertices_plus': similar to 'vertices' mode but where box width, length and depth are defined as
  758. ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
  759. Returns:
  760. 3D Boxes tensor in the ``mode`` format. The shape depends with the ``mode`` value:
  761. * 'vertices' or 'verticies_plus': :math:`(N, 8, 3)` or :math:`(B, N, 8, 3)`.
  762. * Any other value: :math:`(N, 6)` or :math:`(B, N, 6)`.
  763. Note:
  764. It is currently non-differentiable due to a bug. See github issue
  765. `#1304 <https://github.com/kornia/kornia/issues/1396>`_.
  766. Examples:
  767. >>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
  768. >>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
  769. >>> assert (boxes.to_tensor(mode='xyzxyz') == boxes_xyzxyz).all()
  770. """
  771. if self._data.requires_grad:
  772. raise RuntimeError(
  773. "Boxes3D.to_tensor doesn't support computing gradients since they aren't accurate. "
  774. "Please, create boxes from tensors with `requires_grad=False`. "
  775. "This is a known bug. Help is needed to fix it. For more information, "
  776. "see https://github.com/kornia/kornia/issues/1396."
  777. )
  778. batched_boxes = self._data if self._is_batched else self._data.unsqueeze(0)
  779. # Create boxes in xyzxyz_plus format.
  780. boxes = stack([batched_boxes.amin(dim=-2), batched_boxes.amax(dim=-2)], dim=-2).view(
  781. batched_boxes.shape[0], batched_boxes.shape[1], 6
  782. )
  783. mode = mode.lower()
  784. if mode in ("xyzxyz", "xyzxyz_plus"):
  785. pass
  786. elif mode in ("xyzwhd", "vertices", "vertices_plus"):
  787. width = boxes[..., 3] - boxes[..., 0] + 1
  788. height = boxes[..., 4] - boxes[..., 1] + 1
  789. depth = boxes[..., 5] - boxes[..., 2] + 1
  790. boxes[..., 3] = width
  791. boxes[..., 4] = height
  792. boxes[..., 5] = depth
  793. else:
  794. raise ValueError(f"Unknown mode {mode}")
  795. if mode in ("xyzxyz", "vertices"):
  796. offset = torch.as_tensor([0, 0, 0, 1, 1, 1], device=boxes.device, dtype=boxes.dtype)
  797. boxes = boxes + offset
  798. if mode.startswith("vertices"):
  799. xmin, ymin, zmin = boxes[..., 0], boxes[..., 1], boxes[..., 2]
  800. width, height, depth = boxes[..., 3], boxes[..., 4], boxes[..., 5]
  801. boxes = _boxes3d_to_polygons3d(xmin, ymin, zmin, width, height, depth)
  802. boxes = boxes if self._is_batched else boxes.squeeze(0)
  803. return boxes
  804. def to_mask(self, depth: int, height: int, width: int) -> torch.Tensor:
  805. """Convert ·D boxes to masks. Covered area is 1 and the remaining is 0.
  806. Args:
  807. depth: depth of the masked image/images.
  808. height: height of the masked image/images.
  809. width: width of the masked image/images.
  810. Returns:
  811. the output mask tensor, shape of :math:`(N, depth, width, height)` or :math:`(B,N, depth, width, height)`
  812. and dtype of :func:`Boxes3D.dtype` (it can be any floating point dtype).
  813. Note:
  814. It is currently non-differentiable.
  815. Examples:
  816. >>> boxes = Boxes3D(torch.tensor([[ # Equivalent to boxes = Boxes.3Dfrom_tensor([[1,1,1,3,3,2]])
  817. ... [1., 1., 1.],
  818. ... [3., 1., 1.],
  819. ... [3., 3., 1.],
  820. ... [1., 3., 1.],
  821. ... [1., 1., 2.],
  822. ... [3., 1., 2.],
  823. ... [3., 3., 2.],
  824. ... [1., 3., 2.],
  825. ... ]])) # 1x8x3
  826. >>> boxes.to_mask(4, 5, 5)
  827. tensor([[[[0., 0., 0., 0., 0.],
  828. [0., 0., 0., 0., 0.],
  829. [0., 0., 0., 0., 0.],
  830. [0., 0., 0., 0., 0.],
  831. [0., 0., 0., 0., 0.]],
  832. <BLANKLINE>
  833. [[0., 0., 0., 0., 0.],
  834. [0., 1., 1., 1., 0.],
  835. [0., 1., 1., 1., 0.],
  836. [0., 1., 1., 1., 0.],
  837. [0., 0., 0., 0., 0.]],
  838. <BLANKLINE>
  839. [[0., 0., 0., 0., 0.],
  840. [0., 1., 1., 1., 0.],
  841. [0., 1., 1., 1., 0.],
  842. [0., 1., 1., 1., 0.],
  843. [0., 0., 0., 0., 0.]],
  844. <BLANKLINE>
  845. [[0., 0., 0., 0., 0.],
  846. [0., 0., 0., 0., 0.],
  847. [0., 0., 0., 0., 0.],
  848. [0., 0., 0., 0., 0.],
  849. [0., 0., 0., 0., 0.]]]])
  850. """
  851. if self._data.requires_grad:
  852. raise RuntimeError(
  853. "Boxes.to_tensor isn't differentiable. Please, create boxes from tensors with `requires_grad=False`."
  854. )
  855. if self._is_batched: # (B, N, 8, 3)
  856. mask = zeros(
  857. (self._data.shape[0], self._data.shape[1], depth, height, width),
  858. dtype=self._data.dtype,
  859. device=self._data.device,
  860. )
  861. else: # (N, 8, 3)
  862. mask = zeros((self._data.shape[0], depth, height, width), dtype=self._data.dtype, device=self._data.device)
  863. # Boxes coordinates can be outside the image size after transforms. Clamp values to the image size
  864. clipped_boxes_xyzxyz = self.to_tensor("xyzxyz")
  865. clipped_boxes_xyzxyz[..., ::3].clamp_(0, width)
  866. clipped_boxes_xyzxyz[..., 1::3].clamp_(0, height)
  867. clipped_boxes_xyzxyz[..., 2::3].clamp_(0, depth)
  868. # Reshape mask to (BxN, D, H, W) and boxes to (BxN, 6) to iterate over all of them.
  869. # Cast boxes coordinates to be integer to use them as indexes. Use round to handle decimal values.
  870. for mask_channel, box_xyzxyz in zip(
  871. mask.view(-1, depth, height, width), clipped_boxes_xyzxyz.view(-1, 6).round().int()
  872. ):
  873. # Mask channel dimensions: (depth, height, width)
  874. mask_channel[
  875. box_xyzxyz[2] : box_xyzxyz[5], box_xyzxyz[1] : box_xyzxyz[4], box_xyzxyz[0] : box_xyzxyz[3]
  876. ] = 1
  877. return mask
  878. def transform_boxes(self, M: torch.Tensor, inplace: bool = False) -> Boxes3D:
  879. r"""Apply a transformation matrix to the 3D boxes.
  880. Args:
  881. M: The transformation matrix to be applied, shape of :math:`(4, 4)` or :math:`(B, 4, 4)`.
  882. inplace: do transform in-place and return self.
  883. Returns:
  884. The transformed boxes.
  885. """
  886. if not 2 <= M.ndim <= 3 or M.shape[-2:] != (4, 4):
  887. raise ValueError(f"The transformation matrix shape must be (4, 4) or (B, 4, 4). Got {M.shape}.")
  888. transformed_boxes = _transform_boxes(self._data, M)
  889. if inplace:
  890. self._data = transformed_boxes
  891. return self
  892. return Boxes3D(transformed_boxes, False, "xyzxyz_plus")
  893. def transform_boxes_(self, M: torch.Tensor) -> Boxes3D:
  894. """Inplace version of :func:`Boxes3D.transform_boxes`."""
  895. return self.transform_boxes(M, inplace=True)
  896. @property
  897. def data(self) -> torch.Tensor:
  898. return self._data
  899. @property
  900. def mode(self) -> str:
  901. return self._mode
  902. @property
  903. def device(self) -> torch.device:
  904. """Returns boxes device."""
  905. return self._data.device
  906. @property
  907. def dtype(self) -> torch.dtype:
  908. """Returns boxes dtype."""
  909. return self._data.dtype
  910. def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Boxes3D:
  911. """Like :func:`torch.nn.Module.to()` method."""
  912. # In torchscript, dtype is a int and not a class. https://github.com/pytorch/pytorch/issues/51941
  913. if dtype is not None and not _is_floating_point_dtype(dtype):
  914. raise ValueError("Boxes must be in floating point")
  915. self._data = self._data.to(device=device, dtype=dtype)
  916. return self