loss_lw_detr.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import torch
  16. import torch.nn as nn
  17. from ..utils import is_accelerate_available, is_scipy_available, is_vision_available
  18. from .loss_for_object_detection import (
  19. HungarianMatcher,
  20. _set_aux_loss,
  21. box_iou,
  22. dice_loss,
  23. generalized_box_iou,
  24. nested_tensor_from_tensor_list,
  25. sigmoid_focal_loss,
  26. )
  27. if is_vision_available():
  28. from transformers.image_transforms import center_to_corners_format
  29. if is_scipy_available():
  30. from scipy.optimize import linear_sum_assignment
  31. if is_accelerate_available():
  32. from accelerate import PartialState
  33. from accelerate.utils import reduce
  34. class LwDetrHungarianMatcher(HungarianMatcher):
  35. @torch.no_grad()
  36. def forward(self, outputs, targets, group_detr):
  37. """
  38. Differences:
  39. - out_prob = outputs["logits"].flatten(0, 1).sigmoid() instead of softmax
  40. - class_cost uses alpha and gamma
  41. """
  42. batch_size, num_queries = outputs["logits"].shape[:2]
  43. # We flatten to compute the cost matrices in a batch
  44. out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
  45. out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
  46. # Also concat the target labels and boxes
  47. target_ids = torch.cat([v["class_labels"] for v in targets])
  48. target_bbox = torch.cat([v["boxes"] for v in targets])
  49. # Compute the classification cost.
  50. alpha = 0.25
  51. gamma = 2.0
  52. neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
  53. pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
  54. class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
  55. # Compute the L1 cost between boxes, cdist only supports float32
  56. dtype = out_bbox.dtype
  57. out_bbox = out_bbox.to(torch.float32)
  58. target_bbox = target_bbox.to(torch.float32)
  59. bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
  60. bbox_cost = bbox_cost.to(dtype)
  61. # Compute the giou cost between boxes
  62. giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
  63. # Final cost matrix
  64. cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
  65. cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
  66. sizes = [len(v["boxes"]) for v in targets]
  67. indices = []
  68. group_num_queries = num_queries // group_detr
  69. cost_matrix_list = cost_matrix.split(group_num_queries, dim=1)
  70. for group_id in range(group_detr):
  71. group_cost_matrix = cost_matrix_list[group_id]
  72. group_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(group_cost_matrix.split(sizes, -1))]
  73. if group_id == 0:
  74. indices = group_indices
  75. else:
  76. indices = [
  77. (
  78. np.concatenate([indice1[0], indice2[0] + group_num_queries * group_id]),
  79. np.concatenate([indice1[1], indice2[1]]),
  80. )
  81. for indice1, indice2 in zip(indices, group_indices)
  82. ]
  83. return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
  84. class LwDetrImageLoss(nn.Module):
  85. def __init__(self, matcher, num_classes, focal_alpha, losses, group_detr):
  86. super().__init__()
  87. self.matcher = matcher
  88. self.num_classes = num_classes
  89. self.focal_alpha = focal_alpha
  90. self.losses = losses
  91. self.group_detr = group_detr
  92. # removed logging parameter, which was part of the original implementation
  93. def loss_labels(self, outputs, targets, indices, num_boxes):
  94. if "logits" not in outputs:
  95. raise KeyError("No logits were found in the outputs")
  96. source_logits = outputs["logits"]
  97. dtype = source_logits.dtype
  98. idx = self._get_source_permutation_idx(indices)
  99. target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
  100. alpha = self.focal_alpha
  101. gamma = 2
  102. src_boxes = outputs["pred_boxes"][idx]
  103. target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
  104. iou_targets = torch.diag(
  105. box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))[0]
  106. )
  107. # Convert to the same dtype as the source logits as box_iou upcasts to float32
  108. iou_targets = iou_targets.to(dtype)
  109. pos_ious = iou_targets.clone().detach()
  110. prob = source_logits.sigmoid()
  111. # init positive weights and negative weights
  112. pos_weights = torch.zeros_like(source_logits)
  113. # pow promotes to float32 under float16 CUDA autocast; cast back to preserve original dtype
  114. neg_weights = prob.pow(gamma).to(dtype)
  115. pos_ind = idx + (target_classes_o,)
  116. pos_quality = prob[pos_ind].pow(alpha) * pos_ious.pow(1 - alpha)
  117. pos_quality = torch.clamp(pos_quality, 0.01).detach().to(dtype)
  118. pos_weights[pos_ind] = pos_quality
  119. neg_weights[pos_ind] = 1 - pos_quality
  120. loss_ce = -pos_weights * prob.log() - neg_weights * (1 - prob).log()
  121. loss_ce = loss_ce.sum() / num_boxes
  122. losses = {"loss_ce": loss_ce}
  123. return losses
  124. @torch.no_grad()
  125. def loss_cardinality(self, outputs, targets, indices, num_boxes):
  126. """
  127. Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
  128. This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
  129. """
  130. logits = outputs["logits"]
  131. device = logits.device
  132. target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
  133. # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold)
  134. card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1)
  135. card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
  136. losses = {"cardinality_error": card_err}
  137. return losses
  138. # Copied from loss.loss_for_object_detection.ImageLoss.loss_boxes
  139. def loss_boxes(self, outputs, targets, indices, num_boxes):
  140. """
  141. Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
  142. Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
  143. are expected in format (center_x, center_y, w, h), normalized by the image size.
  144. """
  145. if "pred_boxes" not in outputs:
  146. raise KeyError("No predicted boxes found in outputs")
  147. idx = self._get_source_permutation_idx(indices)
  148. source_boxes = outputs["pred_boxes"][idx]
  149. target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
  150. loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
  151. losses = {}
  152. losses["loss_bbox"] = loss_bbox.sum() / num_boxes
  153. loss_giou = 1 - torch.diag(
  154. generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes))
  155. )
  156. losses["loss_giou"] = loss_giou.sum() / num_boxes
  157. return losses
  158. # Copied from loss.loss_for_object_detection.ImageLoss.loss_masks
  159. def loss_masks(self, outputs, targets, indices, num_boxes):
  160. """
  161. Compute the losses related to the masks: the focal loss and the dice loss.
  162. Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
  163. """
  164. if "pred_masks" not in outputs:
  165. raise KeyError("No predicted masks found in outputs")
  166. source_idx = self._get_source_permutation_idx(indices)
  167. target_idx = self._get_target_permutation_idx(indices)
  168. source_masks = outputs["pred_masks"]
  169. source_masks = source_masks[source_idx]
  170. masks = [t["masks"] for t in targets]
  171. # TODO use valid to mask invalid areas due to padding in loss
  172. target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
  173. target_masks = target_masks.to(source_masks)
  174. target_masks = target_masks[target_idx]
  175. # upsample predictions to the target size
  176. source_masks = nn.functional.interpolate(
  177. source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
  178. )
  179. source_masks = source_masks[:, 0].flatten(1)
  180. target_masks = target_masks.flatten(1)
  181. target_masks = target_masks.view(source_masks.shape)
  182. losses = {
  183. "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
  184. "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
  185. }
  186. return losses
  187. # Copied from loss.loss_for_object_detection.ImageLoss._get_source_permutation_idx
  188. def _get_source_permutation_idx(self, indices):
  189. # permute predictions following indices
  190. batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
  191. source_idx = torch.cat([source for (source, _) in indices])
  192. return batch_idx, source_idx
  193. # Copied from loss.loss_for_object_detection.ImageLoss._get_target_permutation_idx
  194. def _get_target_permutation_idx(self, indices):
  195. # permute targets following indices
  196. batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
  197. target_idx = torch.cat([target for (_, target) in indices])
  198. return batch_idx, target_idx
  199. def get_loss(self, loss, outputs, targets, indices, num_boxes):
  200. loss_map = {
  201. "labels": self.loss_labels,
  202. "cardinality": self.loss_cardinality,
  203. "boxes": self.loss_boxes,
  204. "masks": self.loss_masks,
  205. }
  206. if loss not in loss_map:
  207. raise ValueError(f"Loss {loss} not supported")
  208. return loss_map[loss](outputs, targets, indices, num_boxes)
  209. def forward(self, outputs, targets):
  210. """
  211. This performs the loss computation.
  212. Args:
  213. outputs (`dict`, *optional*):
  214. Dictionary of tensors, see the output specification of the model for the format.
  215. targets (`list[dict]`, *optional*):
  216. List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
  217. losses applied, see each loss' doc.
  218. """
  219. group_detr = self.group_detr if self.training else 1
  220. outputs_without_aux_and_enc = {
  221. k: v for k, v in outputs.items() if k != "enc_outputs" and k != "auxiliary_outputs"
  222. }
  223. # Retrieve the matching between the outputs of the last layer and the targets
  224. indices = self.matcher(outputs_without_aux_and_enc, targets, group_detr)
  225. # Compute the average number of target boxes across all nodes, for normalization purposes
  226. num_boxes = sum(len(t["class_labels"]) for t in targets)
  227. num_boxes = num_boxes * group_detr
  228. num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
  229. world_size = 1
  230. if is_accelerate_available():
  231. if PartialState._shared_state != {}:
  232. num_boxes = reduce(num_boxes)
  233. world_size = PartialState().num_processes
  234. num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
  235. # Compute all the requested losses
  236. losses = {}
  237. for loss in self.losses:
  238. losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
  239. # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  240. if "auxiliary_outputs" in outputs:
  241. for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
  242. indices = self.matcher(auxiliary_outputs, targets, group_detr)
  243. for loss in self.losses:
  244. if loss == "masks":
  245. # Intermediate masks losses are too costly to compute, we ignore them.
  246. continue
  247. l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
  248. l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
  249. losses.update(l_dict)
  250. if "enc_outputs" in outputs:
  251. enc_outputs = outputs["enc_outputs"]
  252. indices = self.matcher(enc_outputs, targets, group_detr=group_detr)
  253. for loss in self.losses:
  254. l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes)
  255. l_dict = {k + "_enc": v for k, v in l_dict.items()}
  256. losses.update(l_dict)
  257. return losses
  258. def LwDetrForObjectDetectionLoss(
  259. logits,
  260. labels,
  261. device,
  262. pred_boxes,
  263. config,
  264. outputs_class=None,
  265. outputs_coord=None,
  266. enc_outputs_class=None,
  267. enc_outputs_coord=None,
  268. **kwargs,
  269. ):
  270. # First: create the matcher
  271. matcher = LwDetrHungarianMatcher(
  272. class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
  273. )
  274. # Second: create the criterion
  275. losses = ["labels", "boxes", "cardinality"]
  276. criterion = LwDetrImageLoss(
  277. matcher=matcher,
  278. num_classes=config.num_labels,
  279. focal_alpha=config.focal_alpha,
  280. losses=losses,
  281. group_detr=config.group_detr,
  282. )
  283. criterion.to(device)
  284. # Third: compute the losses, based on outputs and labels
  285. outputs_loss = {}
  286. auxiliary_outputs = None
  287. outputs_loss["logits"] = logits
  288. outputs_loss["pred_boxes"] = pred_boxes
  289. outputs_loss["enc_outputs"] = {
  290. "logits": enc_outputs_class,
  291. "pred_boxes": enc_outputs_coord,
  292. }
  293. if config.auxiliary_loss:
  294. auxiliary_outputs = _set_aux_loss(outputs_class, outputs_coord)
  295. outputs_loss["auxiliary_outputs"] = auxiliary_outputs
  296. loss_dict = criterion(outputs_loss, labels)
  297. # Fourth: compute total loss, as a weighted sum of the various losses
  298. weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
  299. weight_dict["loss_giou"] = config.giou_loss_coefficient
  300. if config.auxiliary_loss:
  301. aux_weight_dict = {}
  302. for i in range(config.decoder_layers - 1):
  303. aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
  304. weight_dict.update(aux_weight_dict)
  305. enc_weight_dict = {k + "_enc": v for k, v in weight_dict.items()}
  306. weight_dict.update(enc_weight_dict)
  307. loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
  308. return loss, loss_dict, auxiliary_outputs