| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- # Copyright 2020 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ..utils import is_scipy_available, is_vision_available, requires_backends
- from .loss_for_object_detection import (
- box_iou,
- dice_loss,
- generalized_box_iou,
- nested_tensor_from_tensor_list,
- sigmoid_focal_loss,
- )
- if is_scipy_available():
- from scipy.optimize import linear_sum_assignment
- if is_vision_available():
- from transformers.image_transforms import center_to_corners_format
- # different for RT-DETR: not slicing the last element like in DETR one
- def _set_aux_loss(outputs_class, outputs_coord):
- return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
class RTDetrHungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        config: RTDetrConfig
            Reads `matcher_class_cost`, `matcher_bbox_cost` and `matcher_giou_cost` (weights of the three cost
            terms), `use_focal_loss` (focal-style vs. softmax classification cost) and `matcher_alpha` /
            `matcher_gamma` (parameters of the focal-style cost).

    Raises:
        ValueError: if all three cost weights are 0.
    """

    def __init__(self, config):
        super().__init__()
        requires_backends(self, ["scipy"])

        self.class_cost = config.matcher_class_cost
        self.bbox_cost = config.matcher_bbox_cost
        self.giou_cost = config.matcher_giou_cost

        self.use_focal_loss = config.use_focal_loss
        self.alpha = config.matcher_alpha
        self.gamma = config.matcher_gamma

        if self.class_cost == self.bbox_cost == self.giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates (center format;
                          they are converted with `center_to_corners_format` for the GIoU cost)

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
            The index tensors are int64 CPU tensors. If the whole batch has no target box, every tuple holds two
            empty tensors.
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # Concat the target labels and boxes of the whole batch
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])
        sizes = [len(v["boxes"]) for v in targets]
        num_total_targets = target_bbox.shape[0]

        # Nothing to match (e.g. a batch made only of background images). Without this early exit the
        # `view(batch_size, num_queries, -1)` reshape of the zero-element cost matrix would raise, because the
        # inferred -1 dimension is ambiguous.
        if num_total_targets == 0:
            return [
                (torch.zeros(0, dtype=torch.int64), torch.zeros(0, dtype=torch.int64)) for _ in range(batch_size)
            ]

        # We flatten to compute the cost matrices in a batch
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        if self.use_focal_loss:
            out_prob = outputs["logits"].flatten(0, 1).sigmoid()
            out_prob = out_prob[:, target_ids]
            # Focal-style cost: positive focal term minus negative focal term for each target class.
            neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
            pos_cost_class = self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
            class_cost = pos_cost_class - neg_cost_class
        else:
            out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
            class_cost = -out_prob[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
        # Compute the giou cost between boxes
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix: [batch_size, num_queries, total number of targets in the batch]
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, num_total_targets).cpu()

        # Split the target axis per image and keep, for image i, only its own block of the i-th chunk.
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
class RTDetrLoss(nn.Module):
    """
    This class computes the losses for RTDetr. The process happens in two steps: 1) we compute hungarian assignment
    between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth /
    prediction (supervise class and box).

    Args:
        config (`RTDetrConfig`):
            Model configuration. Besides the matcher attributes read by `RTDetrHungarianMatcher`, this reads:

            - `num_labels`: number of object categories, omitting the special no-object category.
            - `weight_loss_vfl`, `weight_loss_bbox`, `weight_loss_giou`: weights of `loss_vfl`, `loss_bbox` and
              `loss_giou`. Only losses named in this weight dictionary are returned by `forward`.
            - `eos_coefficient`: relative classification weight of the no-object category (used by `loss_labels`).
            - `focal_loss_alpha`, `focal_loss_gamma`: focal parameters (used by the VFL and focal losses).

    By default only the `"vfl"` and `"boxes"` losses are computed (see `get_loss` for all available losses).
    """

    def __init__(self, config):
        super().__init__()

        self.matcher = RTDetrHungarianMatcher(config)
        self.num_classes = config.num_labels
        self.weight_dict = {
            "loss_vfl": config.weight_loss_vfl,
            "loss_bbox": config.weight_loss_bbox,
            "loss_giou": config.weight_loss_giou,
        }
        self.losses = ["vfl", "boxes"]
        self.eos_coef = config.eos_coefficient
        # Class weights for `loss_labels`: 1 for every real class, `eos_coef` for the trailing no-object class.
        empty_weight = torch.ones(config.num_labels + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)
        self.alpha = config.focal_loss_alpha
        self.gamma = config.focal_loss_gamma

    def _get_target_classes(self, src_logits, targets, indices, idx):
        """
        Build the per-query classification target.

        Returns an int64 tensor shaped like `src_logits.shape[:2]` holding the matched ground-truth class id at every
        matched query position `idx` and `self.num_classes` (the no-object index) everywhere else.
        """
        target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_original
        return target_classes

    def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
        """
        Varifocal classification loss.

        Matched queries are supervised, on their matched class only, with a soft target equal to the IoU between the
        predicted box and its ground-truth box; every other (query, class) entry has target 0. Negatives are weighted by
        `alpha * sigmoid(logit) ** gamma`, positives by their IoU target.
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        if "logits" not in outputs:
            raise KeyError("No predicted logits found in outputs")
        idx = self._get_source_permutation_idx(indices)

        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([_target["boxes"][i] for _target, (_, i) in zip(targets, indices)], dim=0)
        # Pairwise IoU matrix; its diagonal is the IoU of each matched prediction with its own target.
        ious, _ = box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))
        ious = torch.diag(ious)

        src_logits = outputs["logits"]
        dtype = src_logits.dtype
        target_classes = self._get_target_classes(src_logits, targets, indices, idx)
        # Drop the trailing no-object column so unmatched queries get an all-zero one-hot row.
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]

        target_score_original = torch.zeros_like(target_classes, dtype=dtype)
        target_score_original[idx] = ious.to(dtype)
        target_score = target_score_original.unsqueeze(-1) * target

        pred_score = F.sigmoid(src_logits.detach())
        # pow promotes to float32 under float16 CUDA autocast; cast back to preserve original dtype
        weight = (self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score).to(dtype)

        loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none")
        # Mean over queries, sum over batch and classes, rescaled by num_queries / num_boxes.
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {"loss_vfl": loss}

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes]

        The logits must have `num_classes + 1` channels (the last one being no-object): unmatched queries are
        labelled `num_classes`, and the `empty_weight` class weights have `num_classes + 1` entries.
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        src_logits = outputs["logits"]
        idx = self._get_source_permutation_idx(indices)
        target_classes = self._get_target_classes(src_logits, targets, indices, idx)
        # `empty_weight` (registered in `__init__`) down-weights the no-object class by `eos_coef`.
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        return {"loss_ce": loss_ce}

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """
        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. This is not
        really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
        """
        logits = outputs["logits"]
        device = logits.device
        target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold)
        card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1)
        card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
        return {"cardinality_error": card_err}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. Targets dicts must
        contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes are expected in
        format (center_x, center_y, w, h), normalized by the image size.
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        idx = self._get_source_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        # Diagonal of the pairwise GIoU matrix = GIoU of each matched (prediction, target) pair.
        loss_giou = 1 - torch.diag(
            generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the masks: the focal loss and the dice loss. Targets dicts must contain the key
        "masks" containing a tensor of dim [nb_target_boxes, h, w].
        """
        if "pred_masks" not in outputs:
            raise KeyError("No predicted masks found in outputs")

        source_idx = self._get_source_permutation_idx(indices)
        target_idx = self._get_target_permutation_idx(indices)
        source_masks = outputs["pred_masks"]
        source_masks = source_masks[source_idx]
        masks = [t["masks"] for t in targets]
        target_masks, _ = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(source_masks)
        target_masks = target_masks[target_idx]

        # upsample predictions to the target size
        source_masks = nn.functional.interpolate(
            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        source_masks = source_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(source_masks.shape)
        return {
            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
        }

    def loss_labels_bce(self, outputs, targets, indices, num_boxes, log=True):
        """Plain sigmoid BCE classification loss against one-hot targets (all-zero rows for unmatched queries)."""
        src_logits = outputs["logits"]
        idx = self._get_source_permutation_idx(indices)
        target_classes = self._get_target_classes(src_logits, targets, indices, idx)
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1].to(src_logits.dtype)
        loss = F.binary_cross_entropy_with_logits(src_logits, target, reduction="none")
        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {"loss_bce": loss}

    def _get_source_permutation_idx(self, indices):
        # permute predictions following indices: (batch index, query index) for every matched pair
        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
        source_idx = torch.cat([source for (source, _) in indices])
        return batch_idx, source_idx

    def _get_target_permutation_idx(self, indices):
        # permute targets following indices: (batch index, target index) for every matched pair
        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
        target_idx = torch.cat([target for (_, target) in indices])
        return batch_idx, target_idx

    def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
        """
        Sigmoid focal classification loss against one-hot targets (all-zero rows for unmatched queries).

        Reduced like the other label losses: mean over queries, sum over batch and classes, times
        num_queries / num_boxes.
        """
        if "logits" not in outputs:
            raise KeyError("No logits found in outputs")
        src_logits = outputs["logits"]
        idx = self._get_source_permutation_idx(indices)
        target_classes = self._get_target_classes(src_logits, targets, indices, idx)
        # BCE needs a floating-point target.
        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1].to(src_logits.dtype)

        # Element-wise focal loss, computed inline: the shared `sigmoid_focal_loss` helper takes `num_boxes` as its
        # third positional argument (see `loss_masks`), so `alpha`/`gamma` cannot be passed to it positionally.
        prob = src_logits.sigmoid()
        ce_loss = F.binary_cross_entropy_with_logits(src_logits, target, reduction="none")
        p_t = prob * target + (1 - prob) * (1 - target)
        loss = ce_loss * ((1 - p_t) ** self.gamma)
        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss

        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
        return {"loss_focal": loss}

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        """Dispatch to the loss function registered under the name `loss`; raises `ValueError` for unknown names."""
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
            "bce": self.loss_labels_bce,
            "focal": self.loss_labels_focal,
            "vfl": self.loss_labels_vfl,
        }
        if loss not in loss_map:
            raise ValueError(f"Loss {loss} not supported")
        return loss_map[loss](outputs, targets, indices, num_boxes)

    @staticmethod
    def get_cdn_matched_indices(dn_meta, targets):
        """
        Build the fixed (prediction, target) matching used for the denoising queries.

        For every image with `num_gt` boxes, the ground-truth indices `0..num_gt-1` are repeated `dn_num_group` times
        and paired with `dn_meta["dn_positive_idx"][i]`. Images without boxes get two empty int64 tensors.

        Raises:
            ValueError: if the number of positive denoising indices of an image does not equal
                `num_gt * dn_num_group`.
        """
        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
        num_gts = [len(t["class_labels"]) for t in targets]
        device = targets[0]["class_labels"].device

        dn_match_indices = []
        for i, num_gt in enumerate(num_gts):
            if num_gt > 0:
                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
                gt_idx = gt_idx.tile(dn_num_group)
                if len(dn_positive_idx[i]) != len(gt_idx):
                    raise ValueError(
                        f"Expected {len(gt_idx)} positive denoising indices for image {i}, "
                        f"got {len(dn_positive_idx[i])}."
                    )
                dn_match_indices.append((dn_positive_idx[i], gt_idx))
            else:
                dn_match_indices.append(
                    (
                        torch.zeros(0, dtype=torch.int64, device=device),
                        torch.zeros(0, dtype=torch.int64, device=device),
                    )
                )

        return dn_match_indices

    def _compute_losses(self, outputs, targets, indices, num_boxes, suffix="", skip_masks=False):
        """
        Compute every loss in `self.losses`, keep only those present in `self.weight_dict` (multiplied by their
        weight) and append `suffix` to their names.
        """
        weighted_losses = {}
        for loss in self.losses:
            if skip_masks and loss == "masks":
                # Intermediate masks losses are too costly to compute, we ignore them.
                continue
            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
            weighted_losses.update(
                {k + suffix: v * self.weight_dict[k] for k, v in l_dict.items() if k in self.weight_dict}
            )
        return weighted_losses

    def forward(self, outputs, targets):
        """
        This performs the loss computation.

        Args:
             outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
             targets (`list[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                losses applied, see each loss' doc.

        Returns:
            `dict[str, Tensor]` of weighted losses. Auxiliary-layer losses are suffixed with `_aux_{i}` and denoising
            losses with `_dn_{i}`.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Normalizer: number of target boxes in this batch, clamped to 1 so empty batches don't divide by zero.
        # (No cross-process reduction is performed here.)
        num_boxes = max(float(sum(len(t["class_labels"]) for t in targets)), 1.0)

        # Compute all the requested losses
        losses = self._compute_losses(outputs, targets, indices, num_boxes)

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "auxiliary_outputs" in outputs:
            for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
                aux_indices = self.matcher(auxiliary_outputs, targets)
                losses.update(
                    self._compute_losses(
                        auxiliary_outputs, targets, aux_indices, num_boxes, suffix=f"_aux_{i}", skip_masks=True
                    )
                )

        # In case of cdn auxiliary losses. For rtdetr
        if "dn_auxiliary_outputs" in outputs:
            if "denoising_meta_values" not in outputs:
                raise ValueError(
                    "The output must have the 'denoising_meta_values` key. Please, ensure that 'outputs' includes a 'denoising_meta_values' entry."
                )
            dn_meta = outputs["denoising_meta_values"]
            # Denoising queries use a fixed matching instead of the Hungarian matcher.
            dn_indices = self.get_cdn_matched_indices(dn_meta, targets)
            dn_num_boxes = num_boxes * dn_meta["dn_num_group"]

            for i, auxiliary_outputs in enumerate(outputs["dn_auxiliary_outputs"]):
                losses.update(
                    self._compute_losses(
                        auxiliary_outputs, targets, dn_indices, dn_num_boxes, suffix=f"_dn_{i}", skip_masks=True
                    )
                )

        return losses
def RTDetrForObjectDetectionLoss(
    logits,
    labels,
    device,
    pred_boxes,
    config,
    outputs_class=None,
    outputs_coord=None,
    enc_topk_logits=None,
    enc_topk_bboxes=None,
    denoising_meta_values=None,
    **kwargs,
):
    """
    Compute the RT-DETR training loss.

    Args:
        logits: final class logits, handed to the criterion under the "logits" key.
        labels (`list[dict]`): per-image targets, forwarded as `targets` to `RTDetrLoss.forward`.
        device: device the freshly built `RTDetrLoss` (and its buffers) is moved to.
        pred_boxes: final box predictions, handed to the criterion under the "pred_boxes" key.
        config: model config; `RTDetrLoss(config)` is built from it, and `config.auxiliary_loss` toggles the
            auxiliary / denoising terms.
        outputs_class, outputs_coord: stacked intermediate logits / boxes, only read when `config.auxiliary_loss` is
            set. With denoising metadata, dim 2 is split by `denoising_meta_values["dn_num_split"]` (first chunk =
            denoising queries). Every entry along dim 1 except the last (presumably the decoder layers - confirm)
            becomes one auxiliary output.
        enc_topk_logits, enc_topk_bboxes: appended as one extra auxiliary output.
        denoising_meta_values (`dict`, *optional*): when given together with `config.auxiliary_loss`, denoising
            auxiliary losses are added.
        **kwargs: accepted and ignored.

    Returns:
        `(loss, loss_dict, auxiliary_outputs)`:
            - `loss` is the sum of all values of `loss_dict`.
            - `auxiliary_outputs` is `None` unless `config.auxiliary_loss` is set. When set, it is the same list that
              was handed to the criterion, so it also contains the encoder top-k entry.
    """
    criterion = RTDetrLoss(config)
    criterion.to(device)

    outputs_loss = {"logits": logits, "pred_boxes": pred_boxes}
    auxiliary_outputs = None

    if config.auxiliary_loss:
        with_denoising = denoising_meta_values is not None
        if with_denoising:
            split_sizes = denoising_meta_values["dn_num_split"]
            dn_out_coord, outputs_coord = torch.split(outputs_coord, split_sizes, dim=2)
            dn_out_class, outputs_class = torch.split(outputs_class, split_sizes, dim=2)

        # All intermediate entries but the last, iterated along dim 1, then the encoder top-k prediction.
        auxiliary_outputs = _set_aux_loss(outputs_class[:, :-1].transpose(0, 1), outputs_coord[:, :-1].transpose(0, 1))
        auxiliary_outputs.extend(_set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
        outputs_loss["auxiliary_outputs"] = auxiliary_outputs

        if with_denoising:
            outputs_loss["dn_auxiliary_outputs"] = _set_aux_loss(
                dn_out_class.transpose(0, 1), dn_out_coord.transpose(0, 1)
            )
            outputs_loss["denoising_meta_values"] = denoising_meta_values

    loss_dict = criterion(outputs_loss, labels)
    loss = sum(loss_dict.values())
    return loss, loss_dict, auxiliary_outputs
|