loss_grounding_dino.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. from ..image_transforms import center_to_corners_format
  17. from ..utils import is_scipy_available
  18. from .loss_for_object_detection import HungarianMatcher, ImageLoss, _set_aux_loss, generalized_box_iou
  19. if is_scipy_available():
  20. from scipy.optimize import linear_sum_assignment
  21. # Similar to the one used in `DeformableDetr` but we reduce with sum and normalize by num_boxes
  22. # instead of mean.
  23. def sigmoid_focal_loss(
  24. inputs: torch.Tensor,
  25. targets: torch.Tensor,
  26. num_boxes: int,
  27. alpha: float = 0.25,
  28. gamma: float = 2,
  29. ):
  30. """
  31. Loss used in RetinaNet for dense detection: https://huggingface.co/papers/1708.02002.
  32. Args:
  33. inputs (`torch.FloatTensor` of arbitrary shape):
  34. The predictions for each example.
  35. targets (`torch.FloatTensor` with the same shape as `inputs`)
  36. A tensor storing the binary classification label for each element in the `inputs` (0 for the negative class
  37. and 1 for the positive class).
  38. num_boxes (`int`):
  39. The total number of boxes in the batch.
  40. alpha (`float`, *optional*, defaults to 0.25):
  41. Optional weighting factor in the range (0,1) to balance positive vs. negative examples.
  42. gamma (`int`, *optional*, defaults to 2):
  43. Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples.
  44. Returns:
  45. Loss tensor
  46. """
  47. prob = inputs.sigmoid()
  48. ce_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  49. # add modulating factor
  50. p_t = prob * targets + (1 - prob) * (1 - targets)
  51. loss = ce_loss * ((1 - p_t) ** gamma)
  52. if alpha >= 0:
  53. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  54. loss = alpha_t * loss
  55. return loss.sum() / num_boxes
  56. class GroundingDinoHungarianMatcher(HungarianMatcher):
  57. @torch.no_grad()
  58. def forward(self, outputs, targets):
  59. """
  60. Args:
  61. outputs (`dict`):
  62. A dictionary that contains at least these entries:
  63. * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
  64. * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
  65. * "label_maps": Tuple of tensors of dim [num_classes, hidden_dim].
  66. targets (`list[dict]`):
  67. A list of targets (len(targets) = batch_size), where each target is a dict containing:
  68. * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
  69. ground-truth
  70. objects in the target) containing the class labels
  71. * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
  72. Returns:
  73. `list[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
  74. - index_i is the indices of the selected predictions (in order)
  75. - index_j is the indices of the corresponding selected targets (in order)
  76. For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  77. """
  78. batch_size, num_queries = outputs["logits"].shape[:2]
  79. # We flatten to compute the cost matrices in a batch
  80. out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, hidden_dim]
  81. out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
  82. label_maps = outputs["label_maps"]
  83. # First take the label map for each class in each batch and then concatenate them
  84. label_maps = torch.cat([label_map[target["class_labels"]] for label_map, target in zip(label_maps, targets)])
  85. # Normalize label maps based on number of tokens per class
  86. label_maps = label_maps / label_maps.sum(dim=-1, keepdim=True)
  87. # Also concat the target labels and boxes
  88. target_bbox = torch.cat([v["boxes"] for v in targets])
  89. # Compute the classification cost.
  90. alpha = 0.25
  91. gamma = 2.0
  92. neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
  93. pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
  94. # Compute the classification cost by taking pos and neg cost in the appropriate index
  95. class_cost = (pos_cost_class - neg_cost_class) @ label_maps.t()
  96. # Compute the L1 cost between boxes
  97. bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
  98. # Compute the giou cost between boxes
  99. giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
  100. # Final cost matrix
  101. cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
  102. cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
  103. sizes = [len(v["boxes"]) for v in targets]
  104. indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
  105. return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
  106. class GroundingDinoImageLoss(ImageLoss):
  107. """
  108. This class computes the losses for `GroundingDinoForObjectDetection`. The process happens in two steps: 1) we
  109. compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of
  110. matched ground-truth / prediction (supervise class and box).
  111. Args:
  112. matcher (`GroundingDinoHungarianMatcher`):
  113. Module able to compute a matching between targets and proposals.
  114. focal_alpha (`float`):
  115. Alpha parameter in focal loss.
  116. losses (`list[str]`):
  117. List of all the losses to be applied. See `get_loss` for a list of all available losses.
  118. """
  119. def __init__(self, matcher, focal_alpha, losses):
  120. nn.Module.__init__(self)
  121. self.matcher = matcher
  122. self.focal_alpha = focal_alpha
  123. self.losses = losses
  124. @torch.no_grad()
  125. def loss_cardinality(self, outputs, targets, indices, num_boxes):
  126. """
  127. Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
  128. This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
  129. """
  130. logits = outputs["logits"]
  131. device = logits.device
  132. target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
  133. # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold)
  134. card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1)
  135. card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
  136. losses = {"cardinality_error": card_err}
  137. return losses
  138. def _get_target_classes_one_hot(self, outputs, targets, indices):
  139. """
  140. Create one_hot based on the matching indices
  141. """
  142. logits = outputs["logits"]
  143. # Add offsets to class_labels to select the correct label map
  144. class_labels = torch.cat(
  145. [
  146. target["class_labels"][J] + len(outputs["label_maps"][i]) if i > 0 else target["class_labels"][J]
  147. for i, (target, (_, J)) in enumerate(zip(targets, indices))
  148. ]
  149. )
  150. label_maps = torch.cat(outputs["label_maps"], dim=0)
  151. idx = self._get_source_permutation_idx(indices)
  152. target_classes_onehot = torch.zeros_like(logits, device=logits.device, dtype=torch.long)
  153. target_classes_onehot[idx] = label_maps[class_labels].to(torch.long)
  154. return target_classes_onehot
  155. def loss_labels(self, outputs, targets, indices, num_boxes):
  156. """
  157. Classification loss (Binary focal loss) targets dicts must contain the key "class_labels" containing a tensor
  158. of dim [nb_target_boxes]
  159. """
  160. if "logits" not in outputs:
  161. raise KeyError("No logits were found in the outputs")
  162. if "text_mask" not in outputs:
  163. raise KeyError("No text_mask were found in the outputs")
  164. target_classes_onehot = self._get_target_classes_one_hot(outputs, targets, indices)
  165. source_logits = outputs["logits"]
  166. text_mask = outputs["text_mask"]
  167. # Select only valid logits
  168. source_logits = torch.masked_select(source_logits, text_mask)
  169. target_classes_onehot = torch.masked_select(target_classes_onehot, text_mask)
  170. target_classes_onehot = target_classes_onehot.float()
  171. loss_ce = sigmoid_focal_loss(
  172. inputs=source_logits,
  173. targets=target_classes_onehot,
  174. num_boxes=num_boxes,
  175. alpha=self.focal_alpha,
  176. gamma=2,
  177. )
  178. losses = {"loss_ce": loss_ce}
  179. return losses
  180. def GroundingDinoForObjectDetectionLoss(
  181. logits,
  182. labels,
  183. device,
  184. pred_boxes,
  185. config,
  186. label_maps,
  187. text_mask,
  188. outputs_class=None,
  189. outputs_coord=None,
  190. encoder_logits=None,
  191. encoder_pred_boxes=None,
  192. ):
  193. # First: create the matcher
  194. matcher = GroundingDinoHungarianMatcher(
  195. class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
  196. )
  197. # Second: create the criterion
  198. losses = ["labels", "boxes", "cardinality"]
  199. criterion = GroundingDinoImageLoss(
  200. matcher=matcher,
  201. focal_alpha=config.focal_alpha,
  202. losses=losses,
  203. )
  204. criterion.to(device)
  205. # Third: compute the losses, based on outputs and labels
  206. outputs_loss = {}
  207. outputs_loss["logits"] = logits
  208. outputs_loss["pred_boxes"] = pred_boxes
  209. outputs_loss["label_maps"] = label_maps
  210. outputs_loss["text_mask"] = text_mask
  211. auxiliary_outputs = None
  212. if config.auxiliary_loss:
  213. auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord)
  214. for aux_output in auxiliary_outputs:
  215. aux_output["label_maps"] = label_maps
  216. aux_output["text_mask"] = text_mask
  217. outputs_loss["auxiliary_outputs"] = auxiliary_outputs
  218. loss_dict = criterion(outputs_loss, labels)
  219. if config.two_stage:
  220. encoder_outputs_loss = {
  221. "logits": encoder_logits,
  222. "pred_boxes": encoder_pred_boxes,
  223. "label_maps": label_maps,
  224. "text_mask": text_mask,
  225. }
  226. encoder_loss_dict = criterion(encoder_outputs_loss, labels)
  227. encoder_loss_dict = {k + "_enc": v for k, v in encoder_loss_dict.items()}
  228. loss_dict.update(encoder_loss_dict)
  229. # Fourth: compute total loss, as a weighted sum of the various losses
  230. weight_dict = {
  231. "loss_ce": 2.0,
  232. "loss_bbox": config.bbox_loss_coefficient,
  233. "loss_giou": config.giou_loss_coefficient,
  234. }
  235. if config.two_stage:
  236. enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()}
  237. weight_dict.update(enc_weight_dict)
  238. if config.auxiliary_loss:
  239. aux_weight_dict = {}
  240. for i in range(config.decoder_layers - 1):
  241. aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
  242. weight_dict.update(aux_weight_dict)
  243. loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
  244. return loss, loss_dict, auxiliary_outputs