| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533 |
- import torch
- import torchvision
- from torch import Tensor
- from torchvision.extension import _assert_has_ops
- from ..utils import _log_api_usage_once
- from ._box_convert import (
- _box_cxcywh_to_xyxy,
- _box_cxcywhr_to_xywhr,
- _box_xywh_to_xyxy,
- _box_xywhr_to_cxcywhr,
- _box_xywhr_to_xyxyxyxy,
- _box_xyxy_to_cxcywh,
- _box_xyxy_to_xywh,
- _box_xyxyxyxy_to_xywhr,
- )
- from ._utils import _upcast
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Run non-maximum suppression (NMS) over a set of boxes, using
    intersection-over-union (IoU) as the overlap measure.

    NMS repeatedly discards a lower scoring box whenever its IoU with
    another, higher scoring box is greater than ``iou_threshold``.

    When several boxes carry exactly the same score and all of them satisfy
    the IoU criterion with respect to a reference box, the box that ends up
    selected is not guaranteed to match between CPU and GPU. This mirrors how
    argsort behaves in PyTorch in the presence of repeated values.

    Args:
        boxes (Tensor[N, 4]): boxes to run NMS on, expected in
            ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): one score per box.
        iou_threshold (float): every overlapping box with
            IoU > iou_threshold is discarded.

    Returns:
        Tensor: int64 tensor holding the indices of the boxes kept by NMS,
        sorted in decreasing order of scores.
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(nms)
    # Guard from torchvision.extension, invoked before dispatching to the custom op.
    _assert_has_ops()
    kept_indices = torch.ops.torchvision.nms(boxes, scores, iou_threshold)
    return kept_indices
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Apply non-maximum suppression per category in a single call.

    Every value in ``idxs`` identifies a category; boxes that belong to
    different categories never suppress one another.

    Args:
        boxes (Tensor[N, 4]): boxes on which NMS is performed, expected in
            ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2``.
        scores (Tensor[N]): one score per box.
        idxs (Tensor[N]): category index of each box.
        iou_threshold (float): every overlapping box with
            IoU > iou_threshold is discarded.

    Returns:
        Tensor: int64 tensor holding the indices of the boxes kept by NMS,
        sorted in decreasing order of scores.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(batched_nms)
    # Benchmarks that drove the following thresholds are at
    # https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
    # and https://github.com/pytorch/vision/pull/8925
    numel_threshold = 4000 if boxes.device.type == "cpu" else 100_000
    # Small inputs, and every traced call, use the coordinate-offset variant;
    # only large eager inputs fall through to the per-class loop.
    if boxes.numel() <= numel_threshold or torchvision._is_tracing():
        return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
    return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
@torch.jit._script_if_tracing
def _batched_nms_coordinate_trick(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Batched NMS implemented with a single ``nms`` call.

    To keep classes independent, every box is shifted by an offset that
    depends only on its class index and is large enough that boxes of
    different classes can no longer overlap.

    Returns an empty int64 tensor on the device of ``boxes`` when ``boxes``
    has no elements.
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # One more than the largest coordinate, in the dtype/device of ``boxes``.
    one = torch.tensor(1).to(boxes)
    class_stride = boxes.max() + one
    per_box_shift = idxs.to(boxes) * class_stride
    shifted_boxes = boxes + per_box_shift[:, None]
    return nms(shifted_boxes, scores, iou_threshold)
@torch.jit._script_if_tracing
def _batched_nms_vanilla(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Batched NMS implemented by calling ``nms`` once per distinct class.

    Based on the Detectron2 implementation. The kept indices of all classes
    are merged and returned in decreasing order of score.
    """
    selected = torch.zeros_like(scores, dtype=torch.bool)
    for category in torch.unique(idxs):
        members = torch.where(idxs == category)[0]
        # Indices returned by nms are relative to the per-class subset.
        kept_in_class = nms(boxes[members], scores[members], iou_threshold)
        selected[members[kept_in_class]] = True
    kept = torch.where(selected)[0]
    by_score = scores[kept].sort(descending=True)[1]
    return kept[by_score]
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Filter out every box of ``boxes`` that has at least one side shorter
    than ``min_size``.

    .. note::
        For sanitizing a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.SanitizeBoundingBoxes` instead.

    Args:
        boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        min_size (float): minimum side length.

    Returns:
        Tensor[K]: indices of the boxes whose width and height are both
        at least ``min_size``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(remove_small_boxes)
    widths = boxes[..., 2] - boxes[..., 0]
    heights = boxes[..., 3] - boxes[..., 1]
    large_enough = (widths >= min_size) & (heights >= min_size)
    # Only the indices along the first dimension are returned.
    return torch.where(large_enough)[0]
def clip_boxes_to_image(boxes: Tensor, size: tuple[int, int]) -> Tensor:
    """
    Clamp box coordinates so that the boxes fit inside an image of size ``size``.

    .. note::
        For clipping a :class:`~torchvision.tv_tensors.BoundingBoxes` object, consider using
        the transform :func:`~torchvision.transforms.v2.ClampBoundingBoxes` instead.

    Args:
        boxes (Tensor[..., 4]): boxes in ``(x1, y1, x2, y2)`` format
            with ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        size (Tuple[height, width]): size of the image.

    Returns:
        Tensor[..., 4]: clipped boxes
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(clip_boxes_to_image)
    ndim = boxes.dim()
    # Even positions of the last axis are x coordinates, odd positions are y.
    xs = boxes[..., 0::2]
    ys = boxes[..., 1::2]
    height, width = size

    if not torchvision._is_tracing():
        xs = xs.clamp(min=0, max=width)
        ys = ys.clamp(min=0, max=height)
    else:
        # While tracing, the bounds are applied through tensor-valued
        # torch.max / torch.min instead of clamp.
        # NOTE(review): presumably so the bounds are recorded in the trace - confirm.
        xs = torch.max(xs, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        xs = torch.min(xs, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        ys = torch.max(ys, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        ys = torch.min(ys, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))

    # Stacking on a new trailing axis and reshaping interleaves x and y again.
    interleaved = torch.stack((xs, ys), dim=ndim)
    return interleaved.reshape(boxes.shape)
def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Convert :class:`torch.Tensor` boxes from the ``in_fmt`` layout to the ``out_fmt`` layout.

    .. note::
        For converting a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.BoundingBoxes` object
        between different formats,
        consider using :func:`~torchvision.transforms.v2.functional.convert_bounding_box_format` instead.
        Or see the corresponding transform :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat`.

    Supported ``in_fmt`` and ``out_fmt`` strings are:

    ``'xyxy'``: boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.
    This is the format that torchvision utilities expect.

    ``'xywh'``: boxes are represented via corner, width and height, x1, y1 being top left, w, h being
    width and height.

    ``'cxcywh'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h
    being width and height.

    ``'xywhr'``: boxes are represented via corner, width and height, x1, y1 being top left, w, h being
    width and height.
    r is rotation angle w.r.t. the box center by :math:`|r|` degrees counter clock wise in the image plane.

    ``'cxcywhr'``: boxes are represented via centre, width and height, cx, cy being center of box, w, h
    being width and height.
    r is rotation angle w.r.t. the box center by :math:`|r|` degrees counter clock wise in the image plane.

    ``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 top right,
    x3, y3 bottom right, and x4, y4 bottom left.

    Only conversions within the unrotated family (``xyxy``, ``xywh``, ``cxcywh``) and within the
    rotated family (``xywhr``, ``cxcywhr``, ``xyxyxyxy``) are implemented.

    Args:
        boxes (Tensor[..., K]): boxes which will be converted. K is the number of coordinates
            (4 for unrotated bounding boxes, 5 or 8 for rotated bounding boxes). Supports any
            number of leading batch dimensions.
        in_fmt (str): Input format of given boxes. Supported formats are
            ['xyxy', 'xywh', 'cxcywh', 'xywhr', 'cxcywhr', 'xyxyxyxy'].
        out_fmt (str): Output format of given boxes. Supported formats are
            ['xyxy', 'xywh', 'cxcywh', 'xywhr', 'cxcywhr', 'xyxyxyxy'].

    Returns:
        Tensor[..., K]: Boxes into converted format.

    Raises:
        ValueError: if ``in_fmt`` or ``out_fmt`` is not one of the supported strings.
        NotImplementedError: if both formats are supported but no conversion
            between them is implemented.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_convert)

    supported = ("xyxy", "xywh", "cxcywh", "xywhr", "cxcywhr", "xyxyxyxy")
    if in_fmt not in supported or out_fmt not in supported:
        raise ValueError(f"Unsupported Bounding Box Conversions for given in_fmt {in_fmt} and out_fmt {out_fmt}")

    # Same format on both sides: hand back a copy, never the input itself.
    if in_fmt == out_fmt:
        return boxes.clone()

    if in_fmt == "xyxy":
        if out_fmt == "xywh":
            return _box_xyxy_to_xywh(boxes)
        if out_fmt == "cxcywh":
            return _box_xyxy_to_cxcywh(boxes)
    elif in_fmt == "xywh":
        if out_fmt == "xyxy":
            return _box_xywh_to_xyxy(boxes)
        if out_fmt == "cxcywh":
            # No direct converter: go through xyxy.
            return _box_xyxy_to_cxcywh(_box_xywh_to_xyxy(boxes))
    elif in_fmt == "cxcywh":
        if out_fmt == "xyxy":
            return _box_cxcywh_to_xyxy(boxes)
        if out_fmt == "xywh":
            # No direct converter: go through xyxy.
            return _box_xyxy_to_xywh(_box_cxcywh_to_xyxy(boxes))
    elif in_fmt == "xywhr":
        if out_fmt == "cxcywhr":
            return _box_xywhr_to_cxcywhr(boxes)
        if out_fmt == "xyxyxyxy":
            return _box_xywhr_to_xyxyxyxy(boxes)
    elif in_fmt == "cxcywhr":
        if out_fmt == "xywhr":
            return _box_cxcywhr_to_xywhr(boxes)
        if out_fmt == "xyxyxyxy":
            # The intermediate xywhr boxes are cast back to the input dtype
            # before the second conversion.
            as_xywhr = _box_cxcywhr_to_xywhr(boxes).to(boxes.dtype)
            return _box_xywhr_to_xyxyxyxy(as_xywhr)
    elif in_fmt == "xyxyxyxy":
        if out_fmt == "xywhr":
            return _box_xyxyxyxy_to_xywhr(boxes)
        if out_fmt == "cxcywhr":
            # The intermediate xywhr boxes are cast back to the input dtype
            # before the second conversion.
            as_xywhr = _box_xyxyxyxy_to_xywhr(boxes).to(boxes.dtype)
            return _box_xywhr_to_cxcywhr(as_xywhr)

    # Both formats are valid, but the pair mixes the rotated and unrotated families.
    raise NotImplementedError(f"Unsupported Bounding Box Conversions for given in_fmt {in_fmt} and out_fmt {out_fmt}")
def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
    """
    Compute the area of each bounding box, given the format the boxes are stored in.

    Args:
        boxes (Tensor[..., 4]): boxes for which the area will be computed.
        fmt (str): Format of the input boxes.
            Default is "xyxy" to preserve backward compatibility.
            Supported formats are "xyxy", "xywh", and "cxcywh".

    Returns:
        Tensor[...]: Tensor containing the area for each box (the input
        shape without its last dimension).

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_area)

    supported = ("xyxy", "xywh", "cxcywh")
    if fmt not in supported:
        raise ValueError(f"Unsupported Bounding Box area for given format {fmt}")

    upcasted = _upcast(boxes)
    if fmt != "xyxy":
        # "xywh" and "cxcywh" both store width and height in the last two slots.
        return upcasted[..., 2] * upcasted[..., 3]

    widths = upcasted[..., 2] - upcasted[..., 0]
    heights = upcasted[..., 3] - upcasted[..., 1]
    return widths * heights
- # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
- # with slight modifications
def _box_inter_union(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> tuple[Tensor, Tensor]:
    """
    Compute pairwise intersection and union areas between two sets of boxes.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes, stored in ``fmt``.
        boxes2 (Tensor[..., M, 4]): second set of boxes, stored in ``fmt``.
        fmt (str): format of both inputs; one of "xyxy", "xywh", "cxcywh".

    Returns:
        tuple[Tensor, Tensor]: ``(inter, union)``, each of shape ``[..., N, M]``.

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    # Validate before any computation. Previously this check sat after the
    # box_area calls, which already raise for an unsupported fmt, so it could
    # never be reached.
    allowed_fmts = (
        "xyxy",
        "xywh",
        "cxcywh",
    )
    if fmt not in allowed_fmts:
        raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}.")

    area1 = box_area(boxes1, fmt=fmt)
    area2 = box_area(boxes2, fmt=fmt)

    # lt / rb are the top-left and bottom-right corners of each pairwise
    # intersection rectangle, broadcast to [..., N, M, 2].
    if fmt == "xyxy":
        lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2])  # [...,N,M,2]
        rb = torch.min(boxes1[..., None, 2:], boxes2[..., None, :, 2:])  # [...,N,M,2]
    elif fmt == "xywh":
        lt = torch.max(boxes1[..., None, :2], boxes2[..., None, :, :2])  # [...,N,M,2]
        rb = torch.min(
            boxes1[..., None, :2] + boxes1[..., None, 2:], boxes2[..., None, :, :2] + boxes2[..., None, :, 2:]
        )  # [...,N,M,2]
    else:  # fmt == "cxcywh":
        lt = torch.max(
            boxes1[..., None, :2] - boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] - boxes2[..., None, :, 2:] / 2
        )  # [...,N,M,2]
        rb = torch.min(
            boxes1[..., None, :2] + boxes1[..., None, 2:] / 2, boxes2[..., None, :, :2] + boxes2[..., None, :, 2:] / 2
        )  # [...,N,M,2]

    # Non-overlapping pairs give a negative extent; clamp it to zero area.
    wh = _upcast(rb - lt).clamp(min=0)  # [...,N,M,2]
    inter = wh[..., 0] * wh[..., 1]  # [...,N,M]

    union = area1[..., None] + area2[..., None, :] - inter

    return inter, union
def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two sets of boxes
    stored in the given format.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes
        boxes2 (Tensor[..., M, 4]): second set of boxes
        fmt (str): Format of the input boxes.
            Default is "xyxy" to preserve backward compatibility.
            Supported formats are "xyxy", "xywh", and "cxcywh".

    Returns:
        Tensor[..., N, M]: the NxM matrix containing the pairwise IoU values for every element
        in boxes1 and boxes2

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(box_iou)

    supported = ("xyxy", "xywh", "cxcywh")
    if fmt not in supported:
        raise ValueError(f"Unsupported Box IoU Calculation for given format {fmt}.")

    intersection, union_area = _box_inter_union(boxes1, boxes2, fmt=fmt)
    return intersection / union_area
- # Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise generalized intersection-over-union (Jaccard index) of two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes
        boxes2 (Tensor[..., M, 4]): second set of boxes

    Returns:
        Tensor[..., N, M]: the NxM matrix containing the pairwise generalized IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(generalized_box_iou)

    intersection, union_area = _box_inter_union(boxes1, boxes2)
    plain_iou = intersection / union_area

    # Corners of the smallest box enclosing each (boxes1[i], boxes2[j]) pair.
    hull_lt = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
    hull_rb = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])

    hull_wh = _upcast(hull_rb - hull_lt).clamp(min=0)  # [...,N,M,2]
    hull_area = hull_wh[..., 0] * hull_wh[..., 1]

    # Penalty: the fraction of the enclosing box not covered by the union.
    return plain_iou - (hull_area - union_area) / hull_area
def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Compute the pairwise complete intersection-over-union (Jaccard index) of two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes
        boxes2 (Tensor[..., M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[..., N, M]: the NxM matrix containing the pairwise complete IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(complete_box_iou)

    pred = _upcast(boxes1)
    gt = _upcast(boxes2)

    diou, iou = _box_diou_iou(pred, gt, eps)

    # Widths/heights broadcast so that pred varies along N and gt along M.
    pred_w = pred[..., None, 2] - pred[..., None, 0]
    pred_h = pred[..., None, 3] - pred[..., None, 1]
    gt_w = gt[..., None, :, 2] - gt[..., None, :, 0]
    gt_h = gt[..., None, :, 3] - gt[..., None, :, 1]

    # Aspect-ratio consistency term.
    pred_angle = torch.atan(pred_w / pred_h)
    gt_angle = torch.atan(gt_w / gt_h)
    v = (4 / (torch.pi**2)) * torch.pow(pred_angle - gt_angle, 2)

    # The trade-off weight is computed without tracking gradients.
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    return diou - alpha * v
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
    """
    Compute the pairwise distance intersection-over-union (Jaccard index) of two sets of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes
        boxes2 (Tensor[..., M, 4]): second set of boxes
        eps (float, optional): small number to prevent division by zero. Default: 1e-7

    Returns:
        Tensor[..., N, M]: the NxM matrix containing the pairwise distance IoU values
        for every element in boxes1 and boxes2
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(distance_box_iou)

    first = _upcast(boxes1)
    second = _upcast(boxes2)
    # The helper returns (diou, iou); only the distance IoU is needed here.
    diou_and_iou = _box_diou_iou(first, second, eps=eps)
    return diou_and_iou[0]
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> tuple[Tensor, Tensor]:
    """
    Compute the pairwise distance IoU together with the plain IoU.

    Args:
        boxes1 (Tensor[..., N, 4]): first set of boxes in ``(x1, y1, x2, y2)`` format.
        boxes2 (Tensor[..., M, 4]): second set of boxes in ``(x1, y1, x2, y2)`` format.
        eps (float): added to the squared enclosing-box diagonal to prevent division by zero.

    Returns:
        tuple[Tensor, Tensor]: ``(diou, iou)``, each of shape ``[..., N, M]``.
    """
    iou = box_iou(boxes1, boxes2)

    # Smallest box enclosing each (boxes1[i], boxes2[j]) pair.
    hull_lt = torch.min(boxes1[..., None, :2], boxes2[..., None, :, :2])
    hull_rb = torch.max(boxes1[..., None, 2:], boxes2[..., None, :, 2:])
    hull_wh = _upcast(hull_rb - hull_lt).clamp(min=0)  # [...,N,M,2]
    diagonal_sq = (hull_wh[..., 0] ** 2) + (hull_wh[..., 1] ** 2) + eps

    # Box centers.
    cx1 = (boxes1[..., 0] + boxes1[..., 2]) / 2
    cy1 = (boxes1[..., 1] + boxes1[..., 3]) / 2
    cx2 = (boxes2[..., 0] + boxes2[..., 2]) / 2
    cy2 = (boxes2[..., 1] + boxes2[..., 3]) / 2

    # Squared distance between every pair of centers.
    dx = _upcast(cx1[..., None] - cx2[..., None, :])
    dy = _upcast(cy1[..., None] - cy2[..., None, :])
    center_gap_sq = (dx**2) + (dy**2)

    # The distance IoU is the IoU penalized by the squared center distance,
    # normalized by the squared diagonal of the enclosing box.
    return iou - (center_gap_sq / diagonal_sq), iou
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
    """
    Compute the tight bounding box around each of the given masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 <= x2`` and ``0 <= y1 <= y2``.

    .. note::
        Empty masks (all zeros) will return bounding boxes ``[0, 0, 0, 0]``.

    .. warning::
        In most cases the output will guarantee ``x1 < x2`` and ``y1 < y2``. But
        if the input is degenerate, e.g. if a mask is a single row or a single
        column, then the output may have x1 = x2 or y1 = y2.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
            and (H, W) are the spatial dimensions.

    Returns:
        Tensor[N, 4]: bounding boxes
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(masks_to_boxes)

    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

    # Unpacking into three names also requires the input to be 3-D.
    num_masks, height, width = masks.shape

    occupied = masks.bool()
    # Per mask: which rows / columns contain at least one non-zero pixel.
    row_hit = torch.any(occupied, dim=2)
    col_hit = torch.any(occupied, dim=1)
    is_empty = ~torch.any(row_hit, dim=1)

    row_hit_f = row_hit.float()
    col_hit_f = col_hit.float()

    # argmax yields the first maximal index, i.e. the first occupied row/column;
    # running it on the flipped tensor locates the last occupied one.
    top = row_hit_f.argmax(dim=1)
    left = col_hit_f.argmax(dim=1)
    bottom = (height - 1) - row_hit_f.flip(dims=[1]).argmax(dim=1)
    right = (width - 1) - col_hit_f.flip(dims=[1]).argmax(dim=1)

    corners = torch.stack([left, top, right, bottom], dim=1).float()
    # All-zero masks are reported as [0, 0, 0, 0].
    corners[is_empty] = 0
    return corners
|