| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from ..utils import is_vision_available
- from .loss_for_object_detection import box_iou
- from .loss_rt_detr import RTDetrHungarianMatcher, RTDetrLoss
- if is_vision_available():
- from transformers.image_transforms import center_to_corners_format
- def _set_aux_loss(outputs_class, outputs_coord):
- return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
- def _set_aux_loss2(
- outputs_class, outputs_coord, outputs_corners, outputs_ref, teacher_corners=None, teacher_logits=None
- ):
- return [
- {
- "logits": a,
- "pred_boxes": b,
- "pred_corners": c,
- "ref_points": d,
- "teacher_corners": teacher_corners,
- "teacher_logits": teacher_logits,
- }
- for a, b, c, d in zip(outputs_class, outputs_coord, outputs_corners, outputs_ref)
- ]
def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor:
    """
    Build the non-uniform Weighting Function W(n) used for bounding box regression.

    The sequence is symmetric around a central zero: the two outermost entries are
    `±2 * |up[0]| * |reg_scale|`, and the entries in between grow geometrically away from zero.

    Args:
        max_num_bins (int): Max number of the discrete bins. The result holds `2 * (max_num_bins // 2) + 1`
            values, i.e. `max_num_bins + 1` for an even `max_num_bins`.
        up (Tensor): Controls upper bounds of the sequence,
            where maximum offset is ±up * H / W. Only `up[0]` is read.
        reg_scale (float): Controls the curvature of the Weighting Function.
            Larger values result in flatter weights near the central axis W(max_num_bins/2)=0
            and steeper weights at both ends. Only its absolute value is used.

    Returns:
        Tensor: 1-D sequence of Weighting Function values.
    """
    inner_bound = abs(up[0]) * abs(reg_scale)
    outer_bound = inner_bound * 2
    # Geometric ratio chosen so that growth ** (max_num_bins / 2 - 1) == inner_bound + 1.
    growth = (inner_bound + 1) ** (2 / (max_num_bins - 2))
    half = max_num_bins // 2

    pieces = [-outer_bound]
    for exponent in range(half - 1, 0, -1):
        pieces.append(-(growth**exponent) + 1)
    pieces.append(torch.zeros_like(up[0][None]))
    for exponent in range(1, half):
        pieces.append(growth**exponent - 1)
    pieces.append(outer_bound)

    # Most pieces are zero-dimensional tensors; lift them to 1-D so they can be concatenated.
    return torch.cat([piece if piece.dim() > 0 else piece.unsqueeze(0) for piece in pieces], 0)
def translate_gt(gt: torch.Tensor, max_num_bins: int, reg_scale: int, up: torch.Tensor):
    """
    Turn continuous ground-truth (GT) values into a distribution-style target over discrete bins.

    Each GT value is located between two neighbouring entries of the Weighting Function sequence. The function
    reports the index of the entry on its left together with linear interpolation weights for the left and
    right neighbours. Values outside the sequence are pinned to the first or last bin.

    Args:
        gt (Tensor): Ground truth bounding box values; flattened to shape (N, ) internally.
        max_num_bins (int): Maximum number of discrete bins for the distribution.
        reg_scale (float): Controls the curvature of the Weighting Function.
        up (Tensor): Controls the upper bounds of the Weighting Function.

    Returns:
        tuple[Tensor, Tensor, Tensor]:
            - indices (Tensor): Float index of the left bin closest to each GT value, shape (N, ). Values below
              the sequence get `0`; values at or above its last entry get `max_num_bins - 0.1`.
            - weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
            - weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
    """
    flat_gt = gt.reshape(-1)
    bin_edges = weighting_function(max_num_bins, up, reg_scale)

    # Counting the edges that lie at or below a value gives (left index + 1).
    at_or_below = (bin_edges.unsqueeze(0) - flat_gt.unsqueeze(1)) <= 0
    indices = (torch.sum(at_or_below, dim=1) - 1).float()

    weight_right = torch.zeros_like(indices)
    weight_left = torch.zeros_like(indices)

    # Classify every value once, before `indices` is overwritten below.
    below_range = indices < 0
    above_range = indices >= max_num_bins
    in_range = ~(below_range | above_range)

    # In-range values: interpolate linearly between the two surrounding edges.
    left_bin = indices[in_range].long()
    gt_in_range = flat_gt[in_range]
    dist_to_left = torch.abs(gt_in_range - bin_edges[left_bin])
    dist_to_right = torch.abs(bin_edges[left_bin + 1] - gt_in_range)
    weight_right[in_range] = dist_to_left / (dist_to_left + dist_to_right)
    weight_left[in_range] = 1.0 - weight_right[in_range]

    # Values below the first edge: all weight on the first bin.
    weight_right[below_range] = 0.0
    weight_left[below_range] = 1.0
    indices[below_range] = 0.0

    # Values at or beyond the last edge: all weight on the right neighbour of the last bin.
    weight_right[above_range] = 1.0
    weight_left[above_range] = 0.0
    indices[above_range] = max_num_bins - 0.1

    return indices, weight_right, weight_left
def bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1):
    """
    Express bounding boxes as four side distances from a reference point and bin them.

    The left/top/right/bottom distances are normalised by the reference width or height divided by
    `|reg_scale|`, shifted by `0.5 * |reg_scale|`, and then passed through `translate_gt`.

    Args:
        points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
        bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
        max_num_bins (int): Maximum bin value.
        reg_scale (float): Controls the curvature of W(n); only its absolute value is used.
        up (Tensor): Controls the upper bounds of W(n).
        eps (float): Small value to ensure target < max_num_bins.

    Returns:
        tuple[Tensor, Tensor, Tensor]: The flattened bin positions (clamped to `[0, max_num_bins - eps]`), the
        right-bin weights and the left-bin weights. All three are detached from the autograd graph.
    """
    reg_scale = abs(reg_scale)
    half_scale = 0.5 * reg_scale
    # Size of one regression unit along each axis; the tiny constant avoids division by zero.
    unit_w = points[..., 2] / reg_scale + 1e-16
    unit_h = points[..., 3] / reg_scale + 1e-16
    center_x = points[:, 0]
    center_y = points[:, 1]

    side_offsets = torch.stack(
        [
            (center_x - bbox[:, 0]) / unit_w - half_scale,  # left
            (center_y - bbox[:, 1]) / unit_h - half_scale,  # top
            (bbox[:, 2] - center_x) / unit_w - half_scale,  # right
            (bbox[:, 3] - center_y) / unit_h - half_scale,  # bottom
        ],
        -1,
    )

    bin_positions, weight_right, weight_left = translate_gt(side_offsets, max_num_bins, reg_scale, up)
    if max_num_bins is not None:
        bin_positions = bin_positions.clamp(min=0, max=max_num_bins - eps)
    return bin_positions.reshape(-1).detach(), weight_right.detach(), weight_left.detach()
class DFineLoss(RTDetrLoss):
    """
    Loss module for D-FINE, built on top of `RTDetrLoss`.

    On top of what the parent class sets up, it installs an `RTDetrHungarianMatcher`, its own loss weights and
    the `"local"` loss, which produces the Fine-Grained Localization (`loss_fgl`) and Decoupled Distillation
    Focal (`loss_ddf`) terms.

    Args:
        config:
            Configuration object. This class reads `max_num_bins`, `reg_scale`, `up` and the loss weights
            `weight_loss_vfl`, `weight_loss_bbox`, `weight_loss_giou`, `weight_loss_fgl`, `weight_loss_ddf`
            from it; it is also forwarded to `RTDetrLoss` and `RTDetrHungarianMatcher`.
    """

    def __init__(self, config):
        super().__init__(config)

        self.matcher = RTDetrHungarianMatcher(config)
        self.max_num_bins = config.max_num_bins
        # Weight of every loss term, keyed by the name used in the returned loss dictionaries.
        self.weight_dict = {
            "loss_vfl": config.weight_loss_vfl,
            "loss_bbox": config.weight_loss_bbox,
            "loss_giou": config.weight_loss_giou,
            "loss_fgl": config.weight_loss_fgl,
            "loss_ddf": config.weight_loss_ddf,
        }
        self.losses = ["vfl", "boxes", "local"]
        self.reg_scale = config.reg_scale
        # Registered as a frozen parameter so it moves with the module across devices.
        self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False)

    def unimodal_distribution_focal_loss(
        self, pred, label, weight_right, weight_left, weight=None, reduction="sum", avg_factor=None
    ):
        """
        Cross-entropy against the two bins surrounding each target, blended by interpolation weights.

        Args:
            pred (Tensor): Bin logits, one row per target.
            label (Tensor): Fractional bin position per target; its integer part is the left bin.
            weight_right (Tensor): Weight of the right bin (`left + 1`) for each target.
            weight_left (Tensor): Weight of the left bin for each target.
            weight (Tensor, *optional*): Extra per-target weight multiplied into the loss.
            reduction (str): `"sum"` or `"mean"`; any other value leaves the loss unreduced. Ignored when
                `avg_factor` is given.
            avg_factor (*optional*): If given, the summed loss is divided by this value.

        Returns:
            Tensor: The reduced (or unreduced) loss.
        """
        left_bin = label.long()
        right_bin = left_bin + 1

        ce_left = F.cross_entropy(pred, left_bin, reduction="none")
        ce_right = F.cross_entropy(pred, right_bin, reduction="none")
        loss = ce_left * weight_left.reshape(-1) + ce_right * weight_right.reshape(-1)

        if weight is not None:
            loss = loss * weight.float()

        if avg_factor is not None:
            return loss.sum() / avg_factor
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        return loss

    def loss_local(self, outputs, targets, indices, num_boxes, T=5):
        """
        Compute the Fine-Grained Localization (FGL) loss and the Decoupled Distillation Focal (DDF) loss.

        Nothing is computed (an empty dictionary is returned) when `outputs` has no `"pred_corners"` entry.
        As a side effect, `self.fgl_targets` is always overwritten, and `self.num_pos` / `self.num_neg` are
        overwritten whenever the distillation term is actually evaluated.

        Args:
            outputs (dict): Model outputs; reads `"pred_corners"`, `"ref_points"`, `"pred_boxes"`,
                `"teacher_corners"` and `"teacher_logits"`.
            targets (list[dict]): Per-image targets; the `"boxes"` entry of each is used.
            indices: Matching between predictions and targets, one `(source, target)` pair per image.
            num_boxes: Normalisation factor for the FGL loss.
            T: Temperature applied to both distributions in the distillation term.

        Returns:
            dict: `{"loss_fgl": ..., "loss_ddf": ...}`, or `{}` when corners are not predicted.
        """
        losses = {}
        if "pred_corners" not in outputs:
            return losses

        logits_per_side = self.max_num_bins + 1
        idx = self._get_source_permutation_idx(indices)
        matched_gt_boxes = [t["boxes"][i] for t, (_, i) in zip(targets, indices)]
        target_boxes = torch.cat(matched_gt_boxes, dim=0)

        # ---- FGL: supervise the matched corner distributions with binned GT distances ----
        matched_corners = outputs["pred_corners"][idx].reshape(-1, logits_per_side)
        ref_points = outputs["ref_points"][idx].detach()
        with torch.no_grad():
            self.fgl_targets = bbox2distance(
                ref_points,
                center_to_corners_format(target_boxes),
                self.max_num_bins,
                self.reg_scale,
                self.up,
            )
        fgl_bins, weight_right, weight_left = self.fgl_targets

        pairwise_iou = box_iou(
            center_to_corners_format(outputs["pred_boxes"][idx]), center_to_corners_format(target_boxes)
        )[0]
        # The diagonal holds the IoU of every prediction with its own matched target.
        ious = torch.diag(pairwise_iou)
        # One weight per box side (4 per box).
        weight_targets = ious.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()

        losses["loss_fgl"] = self.unimodal_distribution_focal_loss(
            matched_corners,
            fgl_bins,
            weight_right,
            weight_left,
            weight_targets,
            avg_factor=num_boxes,
        )

        # ---- DDF: distil the teacher corner distributions into all predictions ----
        student_corners = outputs["pred_corners"].reshape(-1, logits_per_side)
        teacher_corners = outputs["teacher_corners"].reshape(-1, logits_per_side)
        if torch.equal(student_corners, teacher_corners):
            # Student and teacher coincide: emit a zero that stays connected to the graph.
            losses["loss_ddf"] = student_corners.sum() * 0
            return losses

        weight_targets_local = outputs["teacher_logits"].sigmoid().max(dim=-1)[0]
        mask = torch.zeros_like(weight_targets_local, dtype=torch.bool)
        mask[idx] = True
        mask = mask.unsqueeze(-1).repeat(1, 1, 4).reshape(-1)

        # Matched predictions are weighted by their IoU instead of the teacher confidence.
        weight_targets_local[idx] = ious.reshape_as(weight_targets_local[idx]).to(weight_targets_local.dtype)
        weight_targets_local = weight_targets_local.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()

        kl_per_bin = nn.KLDivLoss(reduction="none")(
            F.log_softmax(student_corners / T, dim=1),
            F.softmax(teacher_corners.detach() / T, dim=1),
        )
        loss_match_local = weight_targets_local * (T**2) * kl_per_bin.sum(-1)

        batch_scale = 1 / outputs["pred_boxes"].shape[0]  # it should be refined
        unmatched = ~mask
        self.num_pos = (mask.sum() * batch_scale) ** 0.5
        self.num_neg = (unmatched.sum() * batch_scale) ** 0.5

        loss_match_local1 = loss_match_local[mask].mean() if mask.any() else 0
        loss_match_local2 = loss_match_local[unmatched].mean() if unmatched.any() else 0
        weighted_total = loss_match_local1 * self.num_pos + loss_match_local2 * self.num_neg
        losses["loss_ddf"] = weighted_total / (self.num_pos + self.num_neg)

        return losses

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        """
        Dispatch to the loss implementation registered under the name `loss`.

        Raises:
            ValueError: If `loss` is not one of `"cardinality"`, `"local"`, `"boxes"`, `"focal"`, `"vfl"`.
        """
        loss_map = {
            "cardinality": self.loss_cardinality,
            "local": self.loss_local,
            "boxes": self.loss_boxes,
            "focal": self.loss_labels_focal,
            "vfl": self.loss_labels_vfl,
        }
        handler = loss_map.get(loss)
        if handler is None:
            raise ValueError(f"Loss {loss} not supported")
        return handler(outputs, targets, indices, num_boxes)
def DFineForObjectDetectionLoss(
    logits,
    labels,
    device,
    pred_boxes,
    config,
    outputs_class=None,
    outputs_coord=None,
    enc_topk_logits=None,
    enc_topk_bboxes=None,
    denoising_meta_values=None,
    predicted_corners=None,
    initial_reference_points=None,
    **kwargs,
):
    """
    Compute the D-FINE object detection loss for one forward pass.

    A fresh `DFineLoss` criterion is built from `config` and moved to `device`. When `config.auxiliary_loss`
    is enabled and `denoising_meta_values` is provided, the stacked outputs are split along dimension 2 into
    denoising and regular parts, and auxiliary / denoising-auxiliary outputs are added to the criterion input.

    Args:
        logits: Final classification logits.
        labels: Targets handed to the criterion.
        device: Device the criterion is moved to.
        pred_boxes: Final predicted boxes; clamped to `[0, 1]` before use.
        config: Model configuration; `auxiliary_loss` is read here, the rest by `DFineLoss`.
        outputs_class: Stacked class logits (indexed as `[:, layer]`).
        outputs_coord: Stacked box predictions (indexed as `[:, layer]`); clamped to `[0, 1]`.
        enc_topk_logits: Logits appended as one extra auxiliary output.
        enc_topk_bboxes: Boxes appended as one extra auxiliary output; clamped to `[0, 1]`.
        denoising_meta_values (dict, *optional*): Must provide `"dn_num_split"`, the split sizes along dim 2.
        predicted_corners: Stacked corner predictions.
        initial_reference_points: Stacked reference points.
        **kwargs: Accepted and ignored.

    Returns:
        tuple: `(loss, loss_dict, auxiliary_outputs)` where `loss` is the sum of all entries of `loss_dict`
        and `auxiliary_outputs` is `None` when no auxiliary outputs were built.
    """
    criterion = DFineLoss(config)
    criterion.to(device)

    outputs_loss = {
        "logits": logits,
        "pred_boxes": pred_boxes.clamp(min=0, max=1),
    }
    auxiliary_outputs = None

    if config.auxiliary_loss and denoising_meta_values is not None:
        split_sizes = denoising_meta_values["dn_num_split"]
        # Each stacked tensor holds the denoising queries first, then the regular queries.
        dn_out_coord, outputs_coord = torch.split(outputs_coord.clamp(min=0, max=1), split_sizes, dim=2)
        dn_out_class, outputs_class = torch.split(outputs_class, split_sizes, dim=2)
        dn_out_corners, out_corners = torch.split(predicted_corners, split_sizes, dim=2)
        dn_out_refs, out_refs = torch.split(initial_reference_points, split_sizes, dim=2)

        # Every layer but the last is an auxiliary output; the last layer acts as the teacher.
        auxiliary_outputs = _set_aux_loss2(
            outputs_class[:, :-1].transpose(0, 1),
            outputs_coord[:, :-1].transpose(0, 1),
            out_corners[:, :-1].transpose(0, 1),
            out_refs[:, :-1].transpose(0, 1),
            out_corners[:, -1],
            outputs_class[:, -1],
        )
        auxiliary_outputs.extend(_set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)]))
        outputs_loss["auxiliary_outputs"] = auxiliary_outputs

        outputs_loss["dn_auxiliary_outputs"] = _set_aux_loss2(
            dn_out_class.transpose(0, 1),
            dn_out_coord.transpose(0, 1),
            dn_out_corners.transpose(0, 1),
            dn_out_refs.transpose(0, 1),
            dn_out_corners[:, -1],
            dn_out_class[:, -1],
        )
        outputs_loss["denoising_meta_values"] = denoising_meta_values

    # Compute the individual losses from the assembled outputs and the labels, then total them.
    loss_dict = criterion(outputs_loss, labels)
    loss = sum(loss_dict.values())
    return loss, loss_dict, auxiliary_outputs
|