| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import torch
- from .confusion_matrix import confusion_matrix
def mean_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int, eps: float = 1e-6) -> torch.Tensor:
    r"""Calculate the per-class Intersection-Over-Union used for mean IoU (mIOU).

    The function internally computes the confusion matrix and derives, for
    every class, the ratio between the diagonal entry (intersection) and the
    row total plus column total minus the diagonal entry (union).

    Args:
        pred: tensor with estimated targets returned by a
            classifier. The shape can be :math:`(B, *)` and must contain integer
            values between 0 and K-1.
        target: tensor with ground truth (correct) target
            values. The shape can be :math:`(B, *)` and must contain integer
            values between 0 and K-1. It must have the same shape and device
            as ``pred``.
        num_classes: total possible number of classes in target. Must be an
            integer greater than or equal to two.
        eps: epsilon for numerical stability.

    Returns:
        a tensor with the intersection-over-union of each class,
        with shape :math:`(B, K)` where K is the number of classes. No
        reduction over the class dimension is applied.

    Raises:
        TypeError: if ``pred`` or ``target`` is not a ``torch.Tensor`` with
            ``torch.int64`` dtype.
        ValueError: if the shapes or devices of ``pred`` and ``target``
            differ, or if ``num_classes`` is not an integer >= 2.

    Example:
        >>> logits = torch.tensor([[0, 1, 0]])
        >>> target = torch.tensor([[0, 1, 0]])
        >>> mean_iou(logits, target, num_classes=3)
        tensor([[1., 1., 1.]])
    """
    # NOTE: the two conditions are joined with `or` so that a non-tensor is
    # rejected before `.dtype` is touched, and a tensor with a wrong dtype is
    # rejected as well.
    if not torch.is_tensor(pred) or pred.dtype != torch.int64:
        got_pred = pred.dtype if torch.is_tensor(pred) else type(pred)
        raise TypeError(f"Input pred type is not a torch.Tensor with torch.int64 dtype. Got {got_pred}")

    if not torch.is_tensor(target) or target.dtype != torch.int64:
        got_target = target.dtype if torch.is_tensor(target) else type(target)
        raise TypeError(f"Input target type is not a torch.Tensor with torch.int64 dtype. Got {got_target}")

    if not pred.shape == target.shape:
        raise ValueError(f"Inputs pred and target must have the same shape. Got: {pred.shape} and {target.shape}")

    if not pred.device == target.device:
        raise ValueError(f"Inputs must be in the same device. Got: {pred.device} - {target.device}")

    if not isinstance(num_classes, int) or num_classes < 2:
        raise ValueError(f"The number of classes must be an integer bigger than one. Got: {num_classes}")

    # we first compute the confusion matrix
    # (indexed below as a batch of square matrices, i.e. dims 1 and 2 are the class axes)
    conf_mat: torch.Tensor = confusion_matrix(pred, target, num_classes)

    # compute the actual intersection over union
    sum_over_row = torch.sum(conf_mat, dim=1)
    sum_over_col = torch.sum(conf_mat, dim=2)
    conf_mat_diag = torch.diagonal(conf_mat, dim1=-2, dim2=-1)
    # the diagonal is counted in both totals, so subtract it once
    denominator = sum_over_row + sum_over_col - conf_mat_diag

    # NOTE: we add epsilon so that samples that are neither in the
    # prediction or ground truth are taken into account.
    ious = (conf_mat_diag + eps) / (denominator + eps)
    return ious
def mean_iou_bbox(boxes_1: torch.Tensor, boxes_2: torch.Tensor) -> torch.Tensor:
    """Compute the IoU of the cartesian product of two sets of boxes.

    Each box in each set shall be (x1, y1, x2, y2), with ``x2 > x1`` and
    ``y2 > y1``.

    Args:
        boxes_1: a tensor of bounding boxes in :math:`(B1, 4)`.
        boxes_2: a tensor of bounding boxes in :math:`(B2, 4)`.

    Returns:
        a tensor in dimensions :math:`(B1, B2)`, representing the
        intersection-over-union of each of the boxes in set 1 with respect to
        each of the boxes in set 2.

    Raises:
        AssertionError: if any box in either set has a non-positive width or
            a non-positive height.

    Example:
        >>> boxes_1 = torch.tensor([[40, 40, 60, 60], [30, 40, 50, 60]])
        >>> boxes_2 = torch.tensor([[40, 50, 60, 70], [30, 40, 40, 50]])
        >>> mean_iou_bbox(boxes_1, boxes_2)
        tensor([[0.3333, 0.0000],
                [0.1429, 0.2500]])
    """
    # TODO: support more box types. e.g. xywh,
    # Both the width AND the height of every box must be strictly positive;
    # checking only one of them would let malformed boxes through.
    widths_1 = boxes_1[:, 2] - boxes_1[:, 0]
    heights_1 = boxes_1[:, 3] - boxes_1[:, 1]
    if not ((widths_1 > 0).all() and (heights_1 > 0).all()):
        raise AssertionError("Boxes_1 does not follow (x1, y1, x2, y2) format.")

    widths_2 = boxes_2[:, 2] - boxes_2[:, 0]
    heights_2 = boxes_2[:, 3] - boxes_2[:, 1]
    if not ((widths_2 > 0).all() and (heights_2 > 0).all()):
        raise AssertionError("Boxes_2 does not follow (x1, y1, x2, y2) format.")

    # find intersection
    lower_bounds = torch.max(boxes_1[:, :2].unsqueeze(1), boxes_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(boxes_1[:, 2:].unsqueeze(1), boxes_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    # non-overlapping pairs give a negative extent; clamp it to an empty intersection
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    intersection = intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)

    # Find areas of each box in both sets
    areas_set_1 = widths_1 * heights_1  # (n1)
    areas_set_2 = widths_2 * heights_2  # (n2)

    # Find the union
    # PyTorch auto-broadcasts singleton dimensions
    union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)

    # areas are strictly positive after validation, so the union is never zero
    return intersection / union  # (n1, n2)
|