| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from typing import Optional, Tuple, cast
- import torch
- from torch import Size
- from kornia.core import Tensor, stack, zeros
- from kornia.geometry.bbox import validate_bbox
- from kornia.geometry.linalg import transform_points
- from kornia.utils import eye_like
- __all__ = ["Boxes", "Boxes3D"]
- def _is_floating_point_dtype(dtype: torch.dtype) -> bool:
- return dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.half)
- def _merge_box_list(boxes: list[torch.Tensor], method: str = "pad") -> tuple[torch.Tensor, list[int]]:
- r"""Merge a list of boxes into one tensor."""
- if not all(box.shape[-2:] == torch.Size([4, 2]) and box.dim() == 3 for box in boxes):
- raise TypeError(f"Input boxes must be a list of (N, 4, 2) shaped. Got: {[box.shape for box in boxes]}.")
- if method == "pad":
- max_N = max(box.shape[0] for box in boxes)
- stats = [max_N - box.shape[0] for box in boxes]
- output = torch.nn.utils.rnn.pad_sequence(boxes, batch_first=True)
- else:
- raise NotImplementedError(f"`{method}` is not implemented.")
- return output, stats
- def _transform_boxes(boxes: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
- """Transform 3D and 2D in kornia format by applying the transformation matrix M.
- Boxes and the transformation matrix could be batched or not.
- Args:
- boxes: 2D quadrilaterals or 3D hexahedrons in kornia format.
- M: the transformation matrix of shape :math:`(3, 3)` or :math:`(B, 3, 3)` for 2D and :math:`(4, 4)` or
- :math:`(B, 4, 4)` for 3D hexahedron.
- """
- M = M if M.is_floating_point() else M.float()
- # Work with batch as kornia.transform_points only supports a batch of points.
- boxes_per_batch, n_points_per_box, coordinates_dimension = boxes.shape[-3:]
- if boxes_per_batch == 0:
- return boxes
- points = boxes.view(-1, n_points_per_box * boxes_per_batch, coordinates_dimension)
- M = M if M.ndim == 3 else M.unsqueeze(0)
- if points.shape[0] != M.shape[0]:
- raise ValueError(
- f"Batch size mismatch. Got {points.shape[0]} for boxes and {M.shape[0]} for the transformation matrix."
- )
- transformed_boxes: torch.Tensor = transform_points(M, points)
- transformed_boxes = transformed_boxes.view_as(boxes)
- return transformed_boxes
- def _boxes_to_polygons(
- xmin: torch.Tensor, ymin: torch.Tensor, width: torch.Tensor, height: torch.Tensor
- ) -> torch.Tensor:
- if not xmin.ndim == ymin.ndim == width.ndim == height.ndim == 2:
- raise ValueError("We expect to create a batch of 2D boxes (quadrilaterals) in vertices format (B, N, 4, 2)")
- # Create (B,N,4,2) with all points in top left position of boxes
- polygons = zeros((xmin.shape[0], xmin.shape[1], 4, 2), device=xmin.device, dtype=xmin.dtype)
- polygons[..., 0] = xmin.unsqueeze(-1)
- polygons[..., 1] = ymin.unsqueeze(-1)
- # Shift top-right, bottom-right, bottom-left points to the right coordinates
- polygons[..., 1, 0] += width - 1 # Top right
- polygons[..., 2, 0] += width - 1 # Bottom right
- polygons[..., 2, 1] += height - 1 # Bottom right
- polygons[..., 3, 1] += height - 1 # Bottom left
- return polygons
- def _boxes_to_quadrilaterals(boxes: torch.Tensor, mode: str = "xyxy", validate_boxes: bool = True) -> torch.Tensor:
- """Convert from boxes to quadrilaterals."""
- mode = mode.lower()
- if mode.startswith("vertices"):
- batched = boxes.ndim == 4
- if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == torch.Size([4, 2])):
- raise ValueError(f"Boxes shape must be (N, 4, 2) or (B, N, 4, 2) when {mode} mode. Got {boxes.shape}.")
- elif mode.startswith("xy"):
- batched = boxes.ndim == 3
- if not (2 <= boxes.ndim <= 3 and boxes.shape[-1] == 4):
- raise ValueError(f"Boxes shape must be (N, 4) or (B, N, 4) when {mode} mode. Got {boxes.shape}.")
- else:
- raise ValueError(f"Unknown mode {mode}")
- boxes = boxes if boxes.is_floating_point() else boxes.float()
- boxes = boxes if batched else boxes.unsqueeze(0)
- if mode.startswith("vertices"):
- if mode == "vertices":
- quadrilaterals = boxes.clone()
- # Here, vertices are quadrilaterals with width and height defined as `width = xmax - xmin` and
- # `height = ymax - ymin`. We need to convert to `width = xmax - xmin + 1` and `height = ymax - ymin + 1` to
- # match with internal Boxes Kornia representation.
- quadrilaterals[..., 1:3, 0] = quadrilaterals[..., 1:3, 0] - 1
- quadrilaterals[..., 2:, 1] = quadrilaterals[..., 2:, 1] - 1
- elif mode == "vertices_plus":
- # Avoid passing reference
- quadrilaterals = boxes.clone()
- else:
- raise ValueError(f"Unknown mode {mode}")
- not validate_boxes or validate_bbox(quadrilaterals)
- elif mode.startswith("xy"):
- if mode == "xyxy":
- height, width = boxes[..., 3] - boxes[..., 1], boxes[..., 2] - boxes[..., 0]
- elif mode == "xyxy_plus":
- height, width = boxes[..., 3] - boxes[..., 1] + 1, boxes[..., 2] - boxes[..., 0] + 1
- elif mode == "xywh":
- height, width = boxes[..., 3], boxes[..., 2]
- else:
- raise ValueError(f"Unknown mode {mode}")
- if validate_boxes:
- if (width <= 0).any():
- raise ValueError("Some boxes have negative widths or 0.")
- if (height <= 0).any():
- raise ValueError("Some boxes have negative heights or 0.")
- xmin, ymin = boxes[..., 0], boxes[..., 1]
- quadrilaterals = _boxes_to_polygons(xmin, ymin, width, height)
- else:
- raise ValueError(f"Unknown mode {mode}")
- quadrilaterals = quadrilaterals if batched else quadrilaterals.squeeze(0)
- return quadrilaterals
- def _boxes3d_to_polygons3d(
- xmin: torch.Tensor,
- ymin: torch.Tensor,
- zmin: torch.Tensor,
- width: torch.Tensor,
- height: torch.Tensor,
- depth: torch.Tensor,
- ) -> torch.Tensor:
- if not xmin.ndim == ymin.ndim == zmin.ndim == width.ndim == height.ndim == depth.ndim == 2:
- raise ValueError("We expect to create a batch of 3D boxes (hexahedrons) in vertices format (B, N, 8, 3)")
- # Front
- # Create (B,N,4,3) with all points in front top left position of boxes
- front_vertices = zeros((xmin.shape[0], xmin.shape[1], 4, 3), device=xmin.device, dtype=xmin.dtype)
- front_vertices[..., 0] = xmin.unsqueeze(-1)
- front_vertices[..., 1] = ymin.unsqueeze(-1)
- front_vertices[..., 2] = zmin.unsqueeze(-1)
- # Shift front-top-right, front-bottom-right, front-bottom-left points to the right coordinates
- front_vertices[..., 1, 0] += width - 1 # Top right
- front_vertices[..., 2, 0] += width - 1 # Bottom right
- front_vertices[..., 2, 1] += height - 1 # Bottom right
- front_vertices[..., 3, 1] += height - 1 # Bottom left
- # Back
- back_vertices = front_vertices.clone()
- back_vertices[..., 2] += depth.unsqueeze(-1) - 1
- polygons3d = torch.cat([front_vertices, back_vertices], dim=-2)
- return polygons3d
class Boxes:
    r"""2D boxes containing N or BxN boxes.

    Args:
        boxes: 2D boxes, shape of :math:`(N, 4, 2)`, :math:`(B, N, 4, 2)` or a list of :math:`(N, 4, 2)`.
            See below for more details.
        raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a
            floating point tensor. True to raise an error when `boxes` isn't a floating point tensor, False
            to cast to float.
        mode: the box format of the input boxes. It is stored and used as the default output format of
            :func:`Boxes.to_tensor`.

    Note:
        **2D boxes format** is defined as a floating data type tensor of shape ``Nx4x2`` or ``BxNx4x2``
        where each box is a `quadrilateral <https://en.wikipedia.org/wiki/Quadrilateral>`_ defined by its
        4 vertices coordinates (A, B, C, D). Coordinates must be in ``x, y`` order. The height and width of
        a box is defined as ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``. Examples of
        `quadrilaterals <https://en.wikipedia.org/wiki/Quadrilateral>`_ are rectangles, rhombus and trapezoids.
    """

    def __init__(
        self,
        boxes: torch.Tensor | list[torch.Tensor],
        raise_if_not_floating_point: bool = True,
        mode: str = "vertices_plus",
    ) -> None:
        # Number of padded boxes per batch entry; only set when the boxes come from a list.
        self._N: Optional[list[int]] = None

        if isinstance(boxes, list):
            boxes, self._N = _merge_box_list(boxes)

        if not isinstance(boxes, torch.Tensor):
            raise TypeError(f"Input boxes is not a Tensor. Got: {type(boxes)}.")

        if not boxes.is_floating_point():
            if raise_if_not_floating_point:
                raise ValueError(f"Coordinates must be in floating point. Got {boxes.dtype}")
            boxes = boxes.float()

        # 0-dim inputs are reshaped (not re-created) before the shape validation below.
        if len(boxes.shape) == 0:
            boxes = boxes.reshape((-1, 4))

        if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == (4, 2)):
            raise ValueError(f"Boxes shape must be (N, 4, 2) or (B, N, 4, 2). Got {boxes.shape}.")

        self._is_batched = boxes.ndim == 4
        self._data = boxes
        self._mode = mode

    def __getitem__(self, key: slice | int | Tensor) -> Boxes:
        """Return a new container holding the boxes selected by ``key``, keeping the current mode."""
        new_box = type(self)(self._data[key], False)
        new_box._mode = self._mode
        return new_box

    def __setitem__(self, key: slice | int | Tensor, value: Boxes) -> Boxes:
        """Overwrite the boxes selected by ``key`` with the vertices of ``value``."""
        self._data[key] = value._data
        return self

    @property
    def shape(self) -> tuple[int, ...] | Size:
        """Shape of the underlying vertices tensor."""
        return self.data.shape

    def get_boxes_shape(self) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Compute boxes heights and widths.

        Returns:
            - Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
            - Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.

        Example:
            >>> boxes_xyxy = torch.tensor([[[1,1,2,2],[1,1,3,2]]])
            >>> boxes = Boxes.from_tensor(boxes_xyxy)
            >>> boxes.get_boxes_shape()
            (tensor([[1., 1.]]), tensor([[1., 2.]]))
        """
        boxes_xywh = cast(torch.Tensor, self.to_tensor("xywh", as_padded_sequence=True))
        widths, heights = boxes_xywh[..., 2], boxes_xywh[..., 3]
        return heights, widths

    def merge(self, boxes: Boxes, inplace: bool = False) -> Boxes:
        """Merge boxes.

        Say, current instance holds :math:`(B, N, 4, 2)` and the incoming boxes holds :math:`(B, M, 4, 2)`,
        the merge results in :math:`(B, N + M, 4, 2)`. Unbatched boxes of shape :math:`(N, 4, 2)` and
        :math:`(M, 4, 2)` are merged into :math:`(N + M, 4, 2)`.

        Args:
            boxes: 2D boxes.
            inplace: do transform in-place and return self.

        Returns:
            The merged boxes.
        """
        # The box axis is always the third from the end, whether the data is batched or not. Concatenating
        # on ``dim=1`` would join the vertices axis of unbatched boxes instead.
        data = torch.cat([self._data, boxes.data], dim=-3)
        if inplace:
            self._data = data
            return self
        obj = self.clone()
        obj._data = data
        return obj

    def index_put(
        self, indices: tuple[Tensor, ...] | list[Tensor], values: Tensor | Boxes, inplace: bool = False
    ) -> Boxes:
        """Write ``values`` at ``indices`` of the vertices tensor, like :func:`torch.Tensor.index_put_`.

        Args:
            indices: tensors used to index into the vertices tensor.
            values: the vertices to write, either as a tensor or as another :class:`Boxes`.
            inplace: do transform in-place and return self.

        Returns:
            The updated boxes.
        """
        target = self if inplace else self.clone()
        source = values.data if isinstance(values, Boxes) else values
        target._data.index_put_(indices, source)
        return target

    def pad(self, padding_size: Tensor) -> Boxes:
        """Pad a bounding box.

        The boxes are shifted in-place by the left padding along x and by the top padding along y.

        Args:
            padding_size: (B, 4). Column 0 is read as the left padding and column 2 as the top padding.

        Returns:
            self.
        """
        if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
            raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
        self._data[..., 0] += padding_size[..., None, :1].to(device=self._data.device)  # left padding
        self._data[..., 1] += padding_size[..., None, 2:3].to(device=self._data.device)  # top padding
        return self

    def unpad(self, padding_size: Tensor) -> Boxes:
        """Unpad a bounding box, i.e. revert :func:`Boxes.pad`.

        The boxes are shifted back in-place by the left padding along x and by the top padding along y.

        Args:
            padding_size: (B, 4). Column 0 is read as the left padding and column 2 as the top padding.

        Returns:
            self.
        """
        if not (len(padding_size.shape) == 2 and padding_size.size(1) == 4):
            raise RuntimeError(f"Expected padding_size as (B, 4). Got {padding_size.shape}.")
        self._data[..., 0] -= padding_size[..., None, :1].to(device=self._data.device)  # left padding
        self._data[..., 1] -= padding_size[..., None, 2:3].to(device=self._data.device)  # top padding
        return self

    def clamp(
        self,
        topleft: Optional[Tensor | tuple[int, int]] = None,
        botright: Optional[Tensor | tuple[int, int]] = None,
        inplace: bool = False,
    ) -> Boxes:
        """Clamp the vertices of batched boxes between ``topleft`` and ``botright``.

        Args:
            topleft: lower ``(x, y)`` bounds per batch entry, shape :math:`(B, 2)`.
            botright: upper ``(x, y)`` bounds per batch entry, shape :math:`(B, 2)`.
            inplace: do transform in-place and return self.

        Returns:
            The clamped boxes.

        Raises:
            NotImplementedError: unless both ``topleft`` and ``botright`` are tensors.
        """
        if not (isinstance(topleft, Tensor) and isinstance(botright, Tensor)):
            raise NotImplementedError

        target = self if inplace else self.clone()
        _data = target._data
        n_boxes = _data.size(1)

        # Every bound is expanded to (B, N, 4) so it lines up with one coordinate of every vertex.
        topleft_x = topleft[:, None, :1].repeat(1, n_boxes, 4)
        _data[..., 0][_data[..., 0] < topleft_x] = topleft_x[_data[..., 0] < topleft_x]
        topleft_y = topleft[:, None, 1:].repeat(1, n_boxes, 4)
        _data[..., 1][_data[..., 1] < topleft_y] = topleft_y[_data[..., 1] < topleft_y]
        botright_x = botright[:, None, :1].repeat(1, n_boxes, 4)
        _data[..., 0][_data[..., 0] > botright_x] = botright_x[_data[..., 0] > botright_x]
        botright_y = botright[:, None, 1:].repeat(1, n_boxes, 4)
        _data[..., 1][_data[..., 1] > botright_y] = botright_y[_data[..., 1] > botright_y]
        return target

    def trim(self, correspondence_preserve: bool = False, inplace: bool = False) -> Boxes:
        """Trim out zero padded boxes.

        Given box arrangements of shape :math:`(4, 4, Box)`:

            == === == === == === == === ==
            -- Box -- Box -- Box -- Box --
            --  0  --  0  -- Box -- Box --
            --  0  -- Box --  0  --  0  --
            --  0  --  0  --  0  --  0  --
            == === == === == === == === ==

        Nothing will change if correspondence_preserve is True. Only pure zero layers will be removed, resulting in
        shape :math:`(4, 3, Box)`:

            == === == === == === == === ==
            -- Box -- Box -- Box -- Box --
            --  0  --  0  -- Box -- Box --
            --  0  -- Box --  0  --  0  --
            == === == === == === == === ==

        Otherwise, you will get :math:`(4, 2, Box)`:

            == === == === == === == === ==
            -- Box -- Box -- Box -- Box --
            --  0  -- Box -- Box -- Box --
            == === == === == === == === ==

        Raises:
            NotImplementedError: always, this method is not implemented yet.
        """
        raise NotImplementedError

    def filter_boxes_by_area(
        self, min_area: Optional[float] = None, max_area: Optional[float] = None, inplace: bool = False
    ) -> Boxes:
        """Zero out the boxes whose area falls outside ``[min_area, max_area]``.

        Filtered boxes are not removed: all their vertices are set to 0.

        Args:
            min_area: boxes with a smaller area are zeroed out. ``None`` disables the lower bound.
            max_area: boxes with a larger area are zeroed out. ``None`` disables the upper bound.
            inplace: do transform in-place and return self.

        Returns:
            The filtered boxes.
        """
        # Areas are computed once, before any box is zeroed out.
        area = self.compute_area()
        target = self if inplace else self.clone()
        if min_area is not None:
            target._data[area < min_area] = 0.0
        if max_area is not None:
            target._data[area > max_area] = 0.0
        return target

    def compute_area(self) -> torch.Tensor:
        """Return :math:`(B, N)`."""
        coords = self._data.view((-1, 4, 2)) if self._data.ndim == 4 else self._data
        # calculate centroid of the box
        centroid = coords.mean(dim=1, keepdim=True)
        # calculate the angle from centroid to each corner
        angles = torch.atan2(coords[..., 1] - centroid[..., 1], coords[..., 0] - centroid[..., 0])
        # sort the corners by angle to get an order for shoelace formula
        _, clockwise_indices = torch.sort(angles, dim=1, descending=True)
        # gather the corners in the new order
        ordered_corners = torch.gather(coords, 1, clockwise_indices.unsqueeze(-1).expand(-1, -1, 2))
        x, y = ordered_corners[..., 0], ordered_corners[..., 1]
        # Gaussian/Shoelace formula https://en.wikipedia.org/wiki/Shoelace_formula
        area = 0.5 * torch.abs(torch.sum((x * torch.roll(y, 1, 1)) - (y * torch.roll(x, 1, 1)), dim=1))
        return area.view(self._data.shape[:2]) if self._data.ndim == 4 else area

    @classmethod
    def from_tensor(
        cls, boxes: torch.Tensor | list[torch.Tensor], mode: str = "xyxy", validate_boxes: bool = True
    ) -> Boxes:
        r"""Create :class:`Boxes` from boxes stored in another format.

        Args:
            boxes: 2D boxes, shape of :math:`(N, 4)`, :math:`(B, N, 4)`, :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
            mode: The format in which the boxes are provided.

                * 'xyxy': boxes are assumed to be in the format ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin``
                  and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
                * 'xyxy_plus': similar to 'xyxy' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
                  With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
                * 'xywh': boxes are assumed to be in the format ``xmin, ymin, width, height`` where
                  ``width = xmax - xmin`` and ``height = ymax - ymin``. With shape :math:`(N, 4)`, :math:`(B, N, 4)`.
                * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
                  *top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
                  box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
                  With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
                * 'vertices_plus': similar to 'vertices' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
                  With shape :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.

            validate_boxes: check if boxes are valid rectangles or not. Valid rectangles are those with width
                and height >= 1 (>= 2 when mode ends with '_plus' suffix).

        Returns:
            :class:`Boxes` class containing the original `boxes` in the format specified by ``mode``.

        Examples:
            >>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
            >>> boxes = Boxes.from_tensor(boxes_xyxy, mode='xyxy')
            >>> boxes.data  # (2, 4, 2)
            tensor([[[0., 3.],
                     [0., 3.],
                     [0., 3.],
                     [0., 3.]],
            <BLANKLINE>
                    [[5., 1.],
                     [7., 1.],
                     [7., 3.],
                     [5., 3.]]])
        """
        quadrilaterals: torch.Tensor | list[torch.Tensor]
        if isinstance(boxes, torch.Tensor):
            quadrilaterals = _boxes_to_quadrilaterals(boxes, mode=mode, validate_boxes=validate_boxes)
        else:
            quadrilaterals = [_boxes_to_quadrilaterals(box, mode, validate_boxes) for box in boxes]

        return cls(quadrilaterals, False, mode)

    def to_tensor(
        self, mode: Optional[str] = None, as_padded_sequence: bool = False
    ) -> torch.Tensor | list[torch.Tensor]:
        r"""Cast :class:`Boxes` to a tensor.

        ``mode`` controls which 2D boxes format should be use to represent boxes in the tensor.

        Args:
            mode: the output box format. Defaults to the mode the boxes were created with. It could be:

                * 'xyxy': boxes are defined as ``xmin, ymin, xmax, ymax`` where ``width = xmax - xmin`` and
                  ``height = ymax - ymin``.
                * 'xyxy_plus': similar to 'xyxy' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.
                * 'xywh': boxes are defined as ``xmin, ymin, width, height`` where ``width = xmax - xmin``
                  and ``height = ymax - ymin``.
                * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
                  *top-left, top-right, bottom-right, bottom-left*. Vertices coordinates are in (x,y) order. Finally,
                  box width and height are defined as ``width = xmax - xmin`` and ``height = ymax - ymin``.
                * 'vertices_plus': similar to 'vertices' mode but where box width and length are defined as
                  ``width = xmax - xmin + 1`` and ``height = ymax - ymin + 1``.

            as_padded_sequence: whether to keep the pads for a list of boxes. This parameter is only valid
                if the boxes are from a box list whilst `from_tensor`.

        Returns:
            Boxes tensor in the ``mode`` format. The shape depends with the ``mode`` value:

                * 'vertices' or 'vertices_plus': :math:`(N, 4, 2)` or :math:`(B, N, 4, 2)`.
                * Any other value: :math:`(N, 4)` or :math:`(B, N, 4)`.

            A list of tensors, one per original list entry and without the pads, is returned when the boxes
            were created from a list and ``as_padded_sequence`` is False.

        Raises:
            ValueError: if ``mode`` is unknown.

        Examples:
            >>> boxes_xyxy = torch.as_tensor([[0, 3, 1, 4], [5, 1, 8, 4]])
            >>> boxes = Boxes.from_tensor(boxes_xyxy)
            >>> assert (boxes_xyxy == boxes.to_tensor(mode='xyxy')).all()
        """
        batched_boxes = self._data if self._is_batched else self._data.unsqueeze(0)

        boxes: torch.Tensor | list[torch.Tensor]

        # Create boxes in xyxy_plus format.
        boxes = torch.stack([batched_boxes.amin(dim=-2), batched_boxes.amax(dim=-2)], dim=-2).view(
            batched_boxes.shape[0], batched_boxes.shape[1], 4
        )

        if mode is None:
            mode = self.mode

        mode = mode.lower()
        if mode in ("xyxy", "xyxy_plus"):
            pass
        elif mode in ("xywh", "vertices", "vertices_plus"):
            height, width = boxes[..., 3] - boxes[..., 1] + 1, boxes[..., 2] - boxes[..., 0] + 1
            boxes[..., 2] = width
            boxes[..., 3] = height
        else:
            raise ValueError(f"Unknown mode {mode}")

        # Non "_plus" formats are one unit larger on the far side.
        if mode in ("xyxy", "vertices"):
            offset = torch.as_tensor([0, 0, 1, 1], device=boxes.device, dtype=boxes.dtype)
            boxes = boxes + offset

        if mode.startswith("vertices"):
            boxes = _boxes_to_polygons(boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3])

        if self._N is not None and not as_padded_sequence:
            # A negative pad on the box axis crops the ``n`` trailing padded boxes of every entry.
            boxes = [torch.nn.functional.pad(o, (len(o.shape) - 1) * [0, 0] + [0, -n]) for o, n in zip(boxes, self._N)]
        else:
            boxes = boxes if self._is_batched else boxes.squeeze(0)
        return boxes

    def to_mask(self, height: int, width: int) -> torch.Tensor:
        """Convert 2D boxes to masks. Covered area is 1 and the remaining is 0.

        Args:
            height: height of the masked image/images.
            width: width of the masked image/images.

        Returns:
            the output mask tensor, shape of :math:`(N, height, width)` or :math:`(B, N, height, width)` and
            dtype of :func:`Boxes.dtype` (it can be any floating point dtype).

        Raises:
            RuntimeError: if the boxes require gradients.

        Note:
            It is currently non-differentiable.

        Examples:
            >>> boxes = Boxes(torch.tensor([[  # one box spanning x in [1, 4] and y in [1, 3]
            ...        [1., 1.],
            ...        [4., 1.],
            ...        [4., 3.],
            ...        [1., 3.],
            ...   ]]))  # 1x4x2
            >>> boxes.to_mask(5, 5)
            tensor([[[0., 0., 0., 0., 0.],
                     [0., 1., 1., 1., 1.],
                     [0., 1., 1., 1., 1.],
                     [0., 1., 1., 1., 1.],
                     [0., 0., 0., 0., 0.]]])
        """
        if self._data.requires_grad:
            raise RuntimeError(
                "Boxes.to_mask isn't differentiable. Please, create boxes from tensors with `requires_grad=False`."
            )

        dtype = self.dtype
        device = self.device
        # (N, height, width) or (B, N, height, width)
        out_shape: Tuple[int, ...] = (*self._data.shape[:-2], height, width)

        # Boxes coordinates can be outside the image size after transforms. Clamp values to the image size
        clipped_boxes_xyxy = cast(Tensor, self.to_tensor("xyxy", as_padded_sequence=True))
        clipped_boxes_xyxy[..., ::2].clamp_(0, width)
        clipped_boxes_xyxy[..., 1::2].clamp_(0, height)
        # Flatten to (BxN, 4) and cast to integers to use the coordinates as indexes. Use round to handle
        # decimal values.
        xyxy = clipped_boxes_xyxy.view(-1, 4).round().long()

        # -----------------
        # CPU Hotpath (loop)
        # -----------------
        if device.type != "cuda":
            mask = torch.zeros(out_shape, dtype=dtype, device=device)
            # Mask channel dimensions: (height, width)
            for mask_channel, (x1, y1, x2, y2) in zip(mask.view(-1, height, width), xyxy.tolist()):
                mask_channel[y1:y2, x1:x2] = 1
            return mask

        # -----------------
        # GPU Hotpath (vectorized)
        # -----------------
        x1, y1, x2, y2 = xyxy.unbind(dim=1)
        ys = torch.arange(height, device=device)
        xs = torch.arange(width, device=device)
        # ``xmax`` / ``ymax`` are exclusive, exactly like the slices of the CPU path.
        y_mask = (ys[None, :] >= y1[:, None]) & (ys[None, :] < y2[:, None])
        x_mask = (xs[None, :] >= x1[:, None]) & (xs[None, :] < x2[:, None])
        masks = (y_mask.unsqueeze(2) & x_mask.unsqueeze(1)).to(dtype)
        return masks.view(*out_shape)

    def transform_boxes(self, M: torch.Tensor, inplace: bool = False) -> Boxes:
        r"""Apply a transformation matrix to the 2D boxes.

        Args:
            M: The transformation matrix to be applied, shape of :math:`(3, 3)` or :math:`(B, 3, 3)`.
            inplace: do transform in-place and return self.

        Returns:
            The transformed boxes.

        Raises:
            ValueError: if ``M`` is not of shape :math:`(3, 3)` or :math:`(B, 3, 3)`.
        """
        if not 2 <= M.ndim <= 3 or M.shape[-2:] != (3, 3):
            raise ValueError(f"The transformation matrix shape must be (3, 3) or (B, 3, 3). Got {M.shape}.")

        transformed_boxes = _transform_boxes(self._data, M)
        if inplace:
            self._data = transformed_boxes
            return self

        obj = self.clone()
        obj._data = transformed_boxes
        return obj

    def transform_boxes_(self, M: torch.Tensor) -> Boxes:
        """Inplace version of :func:`Boxes.transform_boxes`."""
        return self.transform_boxes(M, inplace=True)

    def translate(self, size: Tensor, method: str = "warp", inplace: bool = False) -> Boxes:
        """Translate boxes by the provided size.

        Args:
            size: translate size for x, y direction, shape of :math:`(B, 2)`.
            method: "warp" or "fast". Only "warp" is implemented.
            inplace: do transform in-place and return self.

        Returns:
            The transformed boxes.

        Raises:
            NotImplementedError: if ``method`` is anything but "warp".
        """
        if method != "warp":
            raise NotImplementedError(f"`{method}` is not implemented.")

        # Translation matrices: identities whose last column holds the (x, y) offsets.
        M: Tensor = eye_like(3, size)
        M[:, :2, 2] = size
        return self.transform_boxes(M, inplace=inplace)

    @property
    def data(self) -> torch.Tensor:
        """Underlying vertices tensor."""
        return self._data

    @property
    def mode(self) -> str:
        """Box format the boxes were created with."""
        return self._mode

    @property
    def device(self) -> torch.device:
        """Returns boxes device."""
        return self._data.device

    @property
    def dtype(self) -> torch.dtype:
        """Returns boxes dtype."""
        return self._data.dtype

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Boxes:
        """Like :func:`torch.nn.Module.to()` method.

        The boxes are converted in place and ``self`` is returned.

        Raises:
            ValueError: if ``dtype`` is not a floating point dtype.
        """
        # In torchscript, dtype is a int and not a class. https://github.com/pytorch/pytorch/issues/51941
        if dtype is not None and not _is_floating_point_dtype(dtype):
            raise ValueError("Boxes must be in floating point")
        self._data = self._data.to(device=device, dtype=dtype)
        return self

    def clone(self) -> Boxes:
        """Return a copy of the boxes whose vertices tensor is cloned."""
        obj = type(self)(self._data.clone(), False)
        obj._mode = self._mode
        obj._N = self._N
        obj._is_batched = self._is_batched
        return obj

    def type(self, dtype: torch.dtype) -> Boxes:
        """Cast the boxes to ``dtype`` in place and return ``self``."""
        self._data = self._data.type(dtype)
        return self
class VideoBoxes(Boxes):
    """2D boxes of video clips, given as ``BxTxNx4x2`` vertices.

    The batch and temporal axes are folded together internally, so the boxes are stored as
    ``(B * T, N, 4, 2)``; ``temporal_channel_size`` remembers ``T`` to unfold them again.
    """

    # Number of frames (T) of the tensor the boxes were created from.
    temporal_channel_size: int

    @classmethod
    def from_tensor(  # type: ignore[override]
        cls, boxes: torch.Tensor | list[torch.Tensor], validate_boxes: bool = True
    ) -> VideoBoxes:
        """Create :class:`VideoBoxes` from a ``BxTxNx4x2`` tensor of vertices.

        Args:
            boxes: boxes in ``vertices_plus`` format, shape of :math:`(B, T, N, 4, 2)`.
            validate_boxes: whether to validate the boxes.

        Raises:
            ValueError: if ``boxes`` is a list or is not of shape :math:`(B, T, N, 4, 2)`.
        """
        if isinstance(boxes, list) or boxes.dim() != 5 or boxes.shape[-2:] != torch.Size([4, 2]):
            raise ValueError("Input box type is not yet supported. Please input an `BxTxNx4x2` tensor directly.")

        batch_size, n_frames = boxes.size(0), boxes.size(1)
        # Fold batch and time together: (B, T, N, 4, 2) -> (B * T, N, 4, 2).
        folded = boxes.view(batch_size * n_frames, -1, boxes.size(3), boxes.size(4))
        quadrilaterals = _boxes_to_quadrilaterals(folded, mode="vertices_plus", validate_boxes=validate_boxes)

        video_boxes = cls(quadrilaterals, False, "vertices_plus")
        video_boxes.temporal_channel_size = n_frames
        return video_boxes

    def to_tensor(self, mode: Optional[str] = None) -> torch.Tensor | list[torch.Tensor]:  # type: ignore[override]
        """Cast the boxes to ``mode`` format and unfold the temporal axis as the second dimension."""
        n_frames = self.temporal_channel_size
        folded = super().to_tensor(mode, as_padded_sequence=False)
        if isinstance(folded, Tensor):
            return folded.view(-1, n_frames, *folded.shape[1:])
        # If returns a list of boxes.
        return [entry.view(-1, n_frames, *entry.shape[1:]) for entry in folded]

    def clone(self) -> VideoBoxes:
        """Return a copy of the boxes whose vertices tensor is cloned."""
        # The parent builds the copy through ``type(self)``, so it already is a ``VideoBoxes``.
        obj = cast("VideoBoxes", super().clone())
        obj.temporal_channel_size = self.temporal_channel_size
        return obj
class Boxes3D:
    r"""3D boxes containing N or BxN boxes.

    Args:
        boxes: 3D boxes, shape of :math:`(N,8,3)` or :math:`(B,N,8,3)`. See below for more details.
        raise_if_not_floating_point: flag to control floating point casting behaviour when `boxes` is not a floating
            point tensor. True to raise an error when `boxes` isn't a floating point tensor, False to cast to float.
        mode: name of the box format recorded on the instance and exposed through :attr:`mode`. It does not change
            how ``boxes`` is stored: the vertices tensor is kept as given.

    Raises:
        TypeError: if ``boxes`` is not a tensor.
        ValueError: if ``boxes`` is not floating point and ``raise_if_not_floating_point`` is True, or if its shape
            is neither :math:`(N,8,3)` nor :math:`(B,N,8,3)`.

    Note:
        **3D boxes format** is defined as a floating data type tensor of shape ``Nx8x3`` or ``BxNx8x3`` where each box
        is a `hexahedron <https://en.wikipedia.org/wiki/Hexahedron>`_ defined by its 8 vertices coordinates.
        Coordinates must be in ``x, y, z`` order. The height, width and depth of a box is defined as
        ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``. Examples of
        `hexahedrons <https://en.wikipedia.org/wiki/Hexahedron>`_ are cubes and rhombohedrons.
    """

    def __init__(
        self, boxes: torch.Tensor, raise_if_not_floating_point: bool = True, mode: str = "xyzxyz_plus"
    ) -> None:
        if not isinstance(boxes, torch.Tensor):
            raise TypeError(f"Input boxes is not a Tensor. Got: {type(boxes)}.")

        if not boxes.is_floating_point():
            if raise_if_not_floating_point:
                raise ValueError(f"Coordinates must be in floating point. Got {boxes.dtype}.")
            boxes = boxes.float()

        # Any tensor that is not (N, 8, 3) or (B, N, 8, 3), including 0-dim tensors, is rejected here
        # with a descriptive error instead of failing inside an intermediate reshape.
        if not (3 <= boxes.ndim <= 4 and boxes.shape[-2:] == (8, 3)):
            raise ValueError(f"3D bbox shape must be (N, 8, 3) or (B, N, 8, 3). Got {boxes.shape}.")

        self._is_batched = boxes.ndim == 4
        self._data = boxes
        self._mode = mode

    def __getitem__(self, key: slice | int | Tensor) -> Boxes3D:
        """Index the underlying vertices tensor and wrap the result in a new :class:`Boxes3D` with the same mode.

        The indexed tensor must still be of shape ``(N, 8, 3)`` or ``(B, N, 8, 3)``, otherwise the constructor
        raises ``ValueError``.
        """
        return Boxes3D(self._data[key], False, mode=self._mode)

    def __setitem__(self, key: slice | int | Tensor, value: Boxes3D) -> Boxes3D:
        """Overwrite the vertices selected by ``key`` with those of ``value`` (in-place) and return ``self``."""
        self._data[key] = value._data
        return self

    @property
    def shape(self) -> tuple[int, ...] | Size:
        """Shape of the underlying vertices tensor."""
        return self.data.shape

    def get_boxes_shape(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Compute boxes depths, heights and widths.

        Returns:
            - Boxes depths, shape of :math:`(N,)` or :math:`(B,N)`.
            - Boxes heights, shape of :math:`(N,)` or :math:`(B,N)`.
            - Boxes widths, shape of :math:`(N,)` or :math:`(B,N)`.

        Example:
            >>> boxes_xyzxyz = torch.tensor([[ 0,  1,  2, 10, 21, 32], [3, 4, 5, 43, 54, 65]])
            >>> boxes3d = Boxes3D.from_tensor(boxes_xyzxyz)
            >>> boxes3d.get_boxes_shape()
            (tensor([30., 60.]), tensor([20., 50.]), tensor([10., 40.]))
        """
        boxes_xyzwhd = self.to_tensor(mode="xyzwhd")
        # 'xyzwhd' stores width, height, depth in channels 3, 4, 5; the return order is depth first.
        widths, heights, depths = boxes_xyzwhd[..., 3], boxes_xyzwhd[..., 4], boxes_xyzwhd[..., 5]
        return depths, heights, widths

    @classmethod
    def from_tensor(cls, boxes: torch.Tensor, mode: str = "xyzxyz", validate_boxes: bool = True) -> Boxes3D:
        r"""Create :class:`Boxes3D` from 3D boxes stored in another format.

        Args:
            boxes: 3D boxes, shape of :math:`(N,6)` or :math:`(B,N,6)`. Non floating point tensors are cast to float.
            mode: The format in which the 3D boxes are provided (case-insensitive).

                * 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, length and depth are defined as
                  ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
                * 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
            validate_boxes: when True, raise if any box has a width, height or depth that is zero or negative
                (as computed for the given ``mode``).

        Returns:
            :class:`Boxes3D` class containing the original `boxes` in the format specified by ``mode``.

        Raises:
            ValueError: if ``boxes`` has an unexpected shape, ``mode`` is unknown, or validation fails.

        Examples:
            >>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
            >>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
            >>> boxes.data  # (2, 8, 3)
            tensor([[[0., 3., 6.],
                     [0., 3., 6.],
                     [0., 3., 6.],
                     [0., 3., 6.],
                     [0., 3., 7.],
                     [0., 3., 7.],
                     [0., 3., 7.],
                     [0., 3., 7.]],
            <BLANKLINE>
                    [[5., 1., 3.],
                     [7., 1., 3.],
                     [7., 3., 3.],
                     [5., 3., 3.],
                     [5., 1., 8.],
                     [7., 1., 8.],
                     [7., 3., 8.],
                     [5., 3., 8.]]])
        """
        if not (2 <= boxes.ndim <= 3 and boxes.shape[-1] == 6):
            raise ValueError(f"BBox shape must be (N, 6) or (B, N, 6). Got {boxes.shape}.")

        # Work on a batched (B, N, 6) view; the extra dimension is removed again before returning.
        batched = boxes.ndim == 3
        boxes = boxes if batched else boxes.unsqueeze(0)
        boxes = boxes if boxes.is_floating_point() else boxes.float()

        xmin, ymin, zmin = boxes[..., 0], boxes[..., 1], boxes[..., 2]
        mode = mode.lower()
        if mode == "xyzxyz":
            width = boxes[..., 3] - boxes[..., 0]
            height = boxes[..., 4] - boxes[..., 1]
            depth = boxes[..., 5] - boxes[..., 2]
        elif mode == "xyzxyz_plus":
            width = boxes[..., 3] - boxes[..., 0] + 1
            height = boxes[..., 4] - boxes[..., 1] + 1
            depth = boxes[..., 5] - boxes[..., 2] + 1
        elif mode == "xyzwhd":
            # Channel order is xmin, ymin, zmin, width, height, depth (the same layout `to_tensor` emits).
            width, height, depth = boxes[..., 3], boxes[..., 4], boxes[..., 5]
        else:
            raise ValueError(f"Unknown mode {mode}")

        if validate_boxes:
            if (width <= 0).any():
                raise ValueError("Some boxes have negative widths or 0.")
            if (height <= 0).any():
                raise ValueError("Some boxes have negative heights or 0.")
            if (depth <= 0).any():
                raise ValueError("Some boxes have negative depths or 0.")

        hexahedrons = _boxes3d_to_polygons3d(xmin, ymin, zmin, width, height, depth)
        hexahedrons = hexahedrons if batched else hexahedrons.squeeze(0)
        return cls(hexahedrons, raise_if_not_floating_point=False, mode=mode)

    def to_tensor(self, mode: str = "xyzxyz") -> torch.Tensor:
        r"""Cast :class:`Boxes3D` to a tensor.

        ``mode`` controls which 3D boxes format should be use to represent boxes in the tensor.

        Args:
            mode: The format in which the boxes are provided (case-insensitive).

                * 'xyzxyz': boxes are assumed to be in the format ``xmin, ymin, zmin, xmax, ymax, zmax`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'xyzxyz_plus': similar to 'xyzxyz' mode but where box width, length and depth are defined as
                  ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.
                * 'xyzwhd': boxes are assumed to be in the format ``xmin, ymin, zmin, width, height, depth`` where
                  ``width = xmax - xmin``, ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'vertices': boxes are defined by their vertices points in the following ``clockwise`` order:
                  *front-top-left, front-top-right, front-bottom-right, front-bottom-left, back-top-left,
                  back-top-right, back-bottom-right, back-bottom-left*. Vertices coordinates are in (x,y, z) order.
                  Finally, box width, height and depth are defined as ``width = xmax - xmin``,
                  ``height = ymax - ymin`` and ``depth = zmax - zmin``.
                * 'vertices_plus': similar to 'vertices' mode but where box width, length and depth are defined as
                  ``width = xmax - xmin + 1``, ``height = ymax - ymin + 1`` and ``depth = zmax - zmin + 1``.

        Returns:
            3D Boxes tensor in the ``mode`` format. The shape depends with the ``mode`` value:

                * 'vertices' or 'vertices_plus': :math:`(N, 8, 3)` or :math:`(B, N, 8, 3)`.
                * Any other value: :math:`(N, 6)` or :math:`(B, N, 6)`.

        Raises:
            RuntimeError: if the stored boxes require gradients.
            ValueError: if ``mode`` is unknown.

        Note:
            It is currently non-differentiable due to a bug. See github issue
            `#1396 <https://github.com/kornia/kornia/issues/1396>`_.

        Examples:
            >>> boxes_xyzxyz = torch.as_tensor([[0, 3, 6, 1, 4, 8], [5, 1, 3, 8, 4, 9]])
            >>> boxes = Boxes3D.from_tensor(boxes_xyzxyz, mode='xyzxyz')
            >>> assert (boxes.to_tensor(mode='xyzxyz') == boxes_xyzxyz).all()
        """
        if self._data.requires_grad:
            raise RuntimeError(
                "Boxes3D.to_tensor doesn't support computing gradients since they aren't accurate. "
                "Please, create boxes from tensors with `requires_grad=False`. "
                "This is a known bug. Help is needed to fix it. For more information, "
                "see https://github.com/kornia/kornia/issues/1396."
            )

        batched_boxes = self._data if self._is_batched else self._data.unsqueeze(0)

        # Per-box minimum and maximum over the 8 vertices, flattened to (B, N, 6) as
        # xmin, ymin, zmin, xmax, ymax, zmax. This is the 'xyzxyz_plus' representation.
        boxes = stack([batched_boxes.amin(dim=-2), batched_boxes.amax(dim=-2)], dim=-2).view(
            batched_boxes.shape[0], batched_boxes.shape[1], 6
        )

        mode = mode.lower()
        if mode in ("xyzxyz", "xyzxyz_plus"):
            pass
        elif mode in ("xyzwhd", "vertices", "vertices_plus"):
            # Replace the max corner by the box extent, using the inclusive ('_plus') convention.
            width = boxes[..., 3] - boxes[..., 0] + 1
            height = boxes[..., 4] - boxes[..., 1] + 1
            depth = boxes[..., 5] - boxes[..., 2] + 1
            boxes[..., 3] = width
            boxes[..., 4] = height
            boxes[..., 5] = depth
        else:
            raise ValueError(f"Unknown mode {mode}")

        if mode in ("xyzxyz", "vertices"):
            # Non '_plus' modes: shift the last three channels by one to leave the inclusive convention.
            offset = torch.as_tensor([0, 0, 0, 1, 1, 1], device=boxes.device, dtype=boxes.dtype)
            boxes = boxes + offset

        if mode.startswith("vertices"):
            xmin, ymin, zmin = boxes[..., 0], boxes[..., 1], boxes[..., 2]
            width, height, depth = boxes[..., 3], boxes[..., 4], boxes[..., 5]
            boxes = _boxes3d_to_polygons3d(xmin, ymin, zmin, width, height, depth)

        boxes = boxes if self._is_batched else boxes.squeeze(0)
        return boxes

    def to_mask(self, depth: int, height: int, width: int) -> torch.Tensor:
        """Convert 3D boxes to masks. Covered area is 1 and the remaining is 0.

        Args:
            depth: depth of the masked image/images.
            height: height of the masked image/images.
            width: width of the masked image/images.

        Returns:
            the output mask tensor, shape of :math:`(N, depth, height, width)` or
            :math:`(B, N, depth, height, width)` and dtype of :func:`Boxes3D.dtype` (it can be any floating point
            dtype).

        Raises:
            RuntimeError: if the stored boxes require gradients.

        Note:
            It is currently non-differentiable.

        Examples:
            >>> # Equivalent to Boxes3D.from_tensor(torch.tensor([[1, 1, 1, 3, 3, 2]]), mode='xyzxyz_plus')
            >>> boxes = Boxes3D(torch.tensor([[
            ...        [1., 1., 1.],
            ...        [3., 1., 1.],
            ...        [3., 3., 1.],
            ...        [1., 3., 1.],
            ...        [1., 1., 2.],
            ...        [3., 1., 2.],
            ...        [3., 3., 2.],
            ...        [1., 3., 2.],
            ...    ]]))  # 1x8x3
            >>> boxes.to_mask(4, 5, 5)
            tensor([[[[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]],
            <BLANKLINE>
                     [[0., 0., 0., 0., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 0., 0., 0., 0.]],
            <BLANKLINE>
                     [[0., 0., 0., 0., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 1., 1., 1., 0.],
                      [0., 0., 0., 0., 0.]],
            <BLANKLINE>
                     [[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 0.]]]])
        """
        if self._data.requires_grad:
            raise RuntimeError(
                "Boxes3D.to_mask isn't differentiable. Please, create boxes from tensors with `requires_grad=False`."
            )

        if self._is_batched:  # (B, N, 8, 3)
            mask = zeros(
                (self._data.shape[0], self._data.shape[1], depth, height, width),
                dtype=self._data.dtype,
                device=self._data.device,
            )
        else:  # (N, 8, 3)
            mask = zeros((self._data.shape[0], depth, height, width), dtype=self._data.dtype, device=self._data.device)

        # Boxes coordinates can be outside the image size after transforms. Clamp values to the image size.
        # `to_tensor` returns a fresh tensor, so the in-place clamps below never touch the stored vertices.
        clipped_boxes_xyzxyz = self.to_tensor("xyzxyz")
        clipped_boxes_xyzxyz[..., ::3].clamp_(0, width)  # xmin, xmax
        clipped_boxes_xyzxyz[..., 1::3].clamp_(0, height)  # ymin, ymax
        clipped_boxes_xyzxyz[..., 2::3].clamp_(0, depth)  # zmin, zmax

        # Reshape mask to (BxN, D, H, W) and boxes to (BxN, 6) to iterate over all of them.
        # Cast boxes coordinates to be integer to use them as indexes. Use round to handle decimal values.
        for mask_channel, box_xyzxyz in zip(
            mask.view(-1, depth, height, width), clipped_boxes_xyzxyz.view(-1, 6).round().int()
        ):
            # Mask channel dimensions: (depth, height, width)
            mask_channel[
                box_xyzxyz[2] : box_xyzxyz[5], box_xyzxyz[1] : box_xyzxyz[4], box_xyzxyz[0] : box_xyzxyz[3]
            ] = 1

        return mask

    def transform_boxes(self, M: torch.Tensor, inplace: bool = False) -> Boxes3D:
        r"""Apply a transformation matrix to the 3D boxes.

        Args:
            M: The transformation matrix to be applied, shape of :math:`(4, 4)` or :math:`(B, 4, 4)`.
            inplace: do transform in-place and return self.

        Returns:
            The transformed boxes. The recorded :attr:`mode` is carried over to the result.

        Raises:
            ValueError: if ``M`` is not of shape :math:`(4, 4)` or :math:`(B, 4, 4)`.
        """
        if not 2 <= M.ndim <= 3 or M.shape[-2:] != (4, 4):
            raise ValueError(f"The transformation matrix shape must be (4, 4) or (B, 4, 4). Got {M.shape}.")

        transformed_boxes = _transform_boxes(self._data, M)
        if inplace:
            self._data = transformed_boxes
            return self

        # Keep the recorded mode, as `__getitem__` and the in-place path do.
        return Boxes3D(transformed_boxes, False, self._mode)

    def transform_boxes_(self, M: torch.Tensor) -> Boxes3D:
        """Inplace version of :func:`Boxes3D.transform_boxes`."""
        return self.transform_boxes(M, inplace=True)

    @property
    def data(self) -> torch.Tensor:
        """Returns the underlying vertices tensor."""
        return self._data

    @property
    def mode(self) -> str:
        """Returns the box format name recorded at construction."""
        return self._mode

    @property
    def device(self) -> torch.device:
        """Returns boxes device."""
        return self._data.device

    @property
    def dtype(self) -> torch.dtype:
        """Returns boxes dtype."""
        return self._data.dtype

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Boxes3D:
        """Like :func:`torch.nn.Module.to()` method.

        The stored tensor is replaced by the converted one and ``self`` is returned.

        Raises:
            ValueError: if ``dtype`` is given and is not a floating point dtype.
        """
        # In torchscript, dtype is a int and not a class. https://github.com/pytorch/pytorch/issues/51941
        if dtype is not None and not _is_floating_point_dtype(dtype):
            raise ValueError("Boxes must be in floating point")
        self._data = self._data.to(device=device, dtype=dtype)
        return self
|