loss_rt_detr.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from ..utils import is_scipy_available, is_vision_available, requires_backends
  18. from .loss_for_object_detection import (
  19. box_iou,
  20. dice_loss,
  21. generalized_box_iou,
  22. nested_tensor_from_tensor_list,
  23. sigmoid_focal_loss,
  24. )
  25. if is_scipy_available():
  26. from scipy.optimize import linear_sum_assignment
  27. if is_vision_available():
  28. from transformers.image_transforms import center_to_corners_format
  29. # different for RT-DETR: not slicing the last element like in DETR one
  30. def _set_aux_loss(outputs_class, outputs_coord):
  31. return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
  32. class RTDetrHungarianMatcher(nn.Module):
  33. """This class computes an assignment between the targets and the predictions of the network
  34. For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
  35. predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
  36. un-matched (and thus treated as non-objects).
  37. Args:
  38. config: RTDetrConfig
  39. """
  40. def __init__(self, config):
  41. super().__init__()
  42. requires_backends(self, ["scipy"])
  43. self.class_cost = config.matcher_class_cost
  44. self.bbox_cost = config.matcher_bbox_cost
  45. self.giou_cost = config.matcher_giou_cost
  46. self.use_focal_loss = config.use_focal_loss
  47. self.alpha = config.matcher_alpha
  48. self.gamma = config.matcher_gamma
  49. if self.class_cost == self.bbox_cost == self.giou_cost == 0:
  50. raise ValueError("All costs of the Matcher can't be 0")
  51. @torch.no_grad()
  52. def forward(self, outputs, targets):
  53. """Performs the matching
  54. Params:
  55. outputs: This is a dict that contains at least these entries:
  56. "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
  57. "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
  58. targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
  59. "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
  60. objects in the target) containing the class labels
  61. "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
  62. Returns:
  63. A list of size batch_size, containing tuples of (index_i, index_j) where:
  64. - index_i is the indices of the selected predictions (in order)
  65. - index_j is the indices of the corresponding selected targets (in order)
  66. For each batch element, it holds:
  67. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  68. """
  69. batch_size, num_queries = outputs["logits"].shape[:2]
  70. # We flatten to compute the cost matrices in a batch
  71. out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
  72. # Also concat the target labels and boxes
  73. target_ids = torch.cat([v["class_labels"] for v in targets])
  74. target_bbox = torch.cat([v["boxes"] for v in targets])
  75. # Compute the classification cost. Contrary to the loss, we don't use the NLL,
  76. # but approximate it in 1 - proba[target class].
  77. # The 1 is a constant that doesn't change the matching, it can be omitted.
  78. if self.use_focal_loss:
  79. out_prob = F.sigmoid(outputs["logits"].flatten(0, 1))
  80. out_prob = out_prob[:, target_ids]
  81. neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
  82. pos_cost_class = self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
  83. class_cost = pos_cost_class - neg_cost_class
  84. else:
  85. out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
  86. class_cost = -out_prob[:, target_ids]
  87. # Compute the L1 cost between boxes
  88. bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
  89. # Compute the giou cost between boxes
  90. giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
  91. # Compute the final cost matrix
  92. cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
  93. cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
  94. sizes = [len(v["boxes"]) for v in targets]
  95. indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
  96. return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
  97. class RTDetrLoss(nn.Module):
  98. """
  99. This class computes the losses for RTDetr. The process happens in two steps: 1) we compute hungarian assignment
  100. between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth /
  101. prediction (supervise class and box).
  102. Args:
  103. matcher (`DetrHungarianMatcher`):
  104. Module able to compute a matching between targets and proposals.
  105. weight_dict (`Dict`):
  106. Dictionary relating each loss with its weights. These losses are configured in RTDetrConf as
  107. `weight_loss_vfl`, `weight_loss_bbox`, `weight_loss_giou`
  108. losses (`list[str]`):
  109. List of all the losses to be applied. See `get_loss` for a list of all available losses.
  110. alpha (`float`):
  111. Parameter alpha used to compute the focal loss.
  112. gamma (`float`):
  113. Parameter gamma used to compute the focal loss.
  114. eos_coef (`float`):
  115. Relative classification weight applied to the no-object category.
  116. num_classes (`int`):
  117. Number of object categories, omitting the special no-object category.
  118. """
  119. def __init__(self, config):
  120. super().__init__()
  121. self.matcher = RTDetrHungarianMatcher(config)
  122. self.num_classes = config.num_labels
  123. self.weight_dict = {
  124. "loss_vfl": config.weight_loss_vfl,
  125. "loss_bbox": config.weight_loss_bbox,
  126. "loss_giou": config.weight_loss_giou,
  127. }
  128. self.losses = ["vfl", "boxes"]
  129. self.eos_coef = config.eos_coefficient
  130. empty_weight = torch.ones(config.num_labels + 1)
  131. empty_weight[-1] = self.eos_coef
  132. self.register_buffer("empty_weight", empty_weight)
  133. self.alpha = config.focal_loss_alpha
  134. self.gamma = config.focal_loss_gamma
  135. def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True):
  136. if "pred_boxes" not in outputs:
  137. raise KeyError("No predicted boxes found in outputs")
  138. if "logits" not in outputs:
  139. raise KeyError("No predicted logits found in outputs")
  140. idx = self._get_source_permutation_idx(indices)
  141. src_boxes = outputs["pred_boxes"][idx]
  142. target_boxes = torch.cat([_target["boxes"][i] for _target, (_, i) in zip(targets, indices)], dim=0)
  143. ious, _ = box_iou(center_to_corners_format(src_boxes.detach()), center_to_corners_format(target_boxes))
  144. ious = torch.diag(ious)
  145. src_logits = outputs["logits"]
  146. dtype = src_logits.dtype
  147. target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
  148. target_classes = torch.full(
  149. src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
  150. )
  151. target_classes[idx] = target_classes_original
  152. target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
  153. target_score_original = torch.zeros_like(target_classes, dtype=dtype)
  154. target_score_original[idx] = ious.to(dtype)
  155. target_score = target_score_original.unsqueeze(-1) * target
  156. pred_score = F.sigmoid(src_logits.detach())
  157. # pow promotes to float32 under float16 CUDA autocast; cast back to preserve original dtype
  158. weight = (self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score).to(dtype)
  159. loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction="none")
  160. loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
  161. return {"loss_vfl": loss}
  162. def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
  163. """Classification loss (NLL)
  164. targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes]
  165. """
  166. if "logits" not in outputs:
  167. raise KeyError("No logits were found in the outputs")
  168. src_logits = outputs["logits"]
  169. idx = self._get_source_permutation_idx(indices)
  170. target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
  171. target_classes = torch.full(
  172. src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
  173. )
  174. target_classes[idx] = target_classes_original
  175. loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.class_weight)
  176. losses = {"loss_ce": loss_ce}
  177. return losses
  178. @torch.no_grad()
  179. def loss_cardinality(self, outputs, targets, indices, num_boxes):
  180. """
  181. Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes. This is not
  182. really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
  183. """
  184. logits = outputs["logits"]
  185. device = logits.device
  186. target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
  187. # Count the number of predictions that are NOT "no-object" (sigmoid > 0.5 threshold)
  188. card_pred = (logits.sigmoid().max(-1).values > 0.5).sum(1)
  189. card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
  190. losses = {"cardinality_error": card_err}
  191. return losses
  192. def loss_boxes(self, outputs, targets, indices, num_boxes):
  193. """
  194. Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. Targets dicts must
  195. contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes are expected in
  196. format (center_x, center_y, w, h), normalized by the image size.
  197. """
  198. if "pred_boxes" not in outputs:
  199. raise KeyError("No predicted boxes found in outputs")
  200. idx = self._get_source_permutation_idx(indices)
  201. src_boxes = outputs["pred_boxes"][idx]
  202. target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
  203. losses = {}
  204. loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
  205. losses["loss_bbox"] = loss_bbox.sum() / num_boxes
  206. loss_giou = 1 - torch.diag(
  207. generalized_box_iou(center_to_corners_format(src_boxes), center_to_corners_format(target_boxes))
  208. )
  209. losses["loss_giou"] = loss_giou.sum() / num_boxes
  210. return losses
  211. def loss_masks(self, outputs, targets, indices, num_boxes):
  212. """
  213. Compute the losses related to the masks: the focal loss and the dice loss. Targets dicts must contain the key
  214. "masks" containing a tensor of dim [nb_target_boxes, h, w].
  215. """
  216. if "pred_masks" not in outputs:
  217. raise KeyError("No predicted masks found in outputs")
  218. source_idx = self._get_source_permutation_idx(indices)
  219. target_idx = self._get_target_permutation_idx(indices)
  220. source_masks = outputs["pred_masks"]
  221. source_masks = source_masks[source_idx]
  222. masks = [t["masks"] for t in targets]
  223. target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
  224. target_masks = target_masks.to(source_masks)
  225. target_masks = target_masks[target_idx]
  226. # upsample predictions to the target size
  227. source_masks = nn.functional.interpolate(
  228. source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
  229. )
  230. source_masks = source_masks[:, 0].flatten(1)
  231. target_masks = target_masks.flatten(1)
  232. target_masks = target_masks.view(source_masks.shape)
  233. losses = {
  234. "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
  235. "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
  236. }
  237. return losses
  238. def loss_labels_bce(self, outputs, targets, indices, num_boxes, log=True):
  239. src_logits = outputs["logits"]
  240. idx = self._get_source_permutation_idx(indices)
  241. target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
  242. target_classes = torch.full(
  243. src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
  244. )
  245. target_classes[idx] = target_classes_original
  246. target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
  247. loss = F.binary_cross_entropy_with_logits(src_logits, target * 1.0, reduction="none")
  248. loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
  249. return {"loss_bce": loss}
  250. def _get_source_permutation_idx(self, indices):
  251. # permute predictions following indices
  252. batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
  253. source_idx = torch.cat([source for (source, _) in indices])
  254. return batch_idx, source_idx
  255. def _get_target_permutation_idx(self, indices):
  256. # permute targets following indices
  257. batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
  258. target_idx = torch.cat([target for (_, target) in indices])
  259. return batch_idx, target_idx
  260. def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True):
  261. if "logits" not in outputs:
  262. raise KeyError("No logits found in outputs")
  263. src_logits = outputs["logits"]
  264. idx = self._get_source_permutation_idx(indices)
  265. target_classes_original = torch.cat([_target["class_labels"][i] for _target, (_, i) in zip(targets, indices)])
  266. target_classes = torch.full(
  267. src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
  268. )
  269. target_classes[idx] = target_classes_original
  270. target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
  271. loss = sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma)
  272. loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
  273. return {"loss_focal": loss}
  274. def get_loss(self, loss, outputs, targets, indices, num_boxes):
  275. loss_map = {
  276. "labels": self.loss_labels,
  277. "cardinality": self.loss_cardinality,
  278. "boxes": self.loss_boxes,
  279. "masks": self.loss_masks,
  280. "bce": self.loss_labels_bce,
  281. "focal": self.loss_labels_focal,
  282. "vfl": self.loss_labels_vfl,
  283. }
  284. if loss not in loss_map:
  285. raise ValueError(f"Loss {loss} not supported")
  286. return loss_map[loss](outputs, targets, indices, num_boxes)
  287. @staticmethod
  288. def get_cdn_matched_indices(dn_meta, targets):
  289. dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
  290. num_gts = [len(t["class_labels"]) for t in targets]
  291. device = targets[0]["class_labels"].device
  292. dn_match_indices = []
  293. for i, num_gt in enumerate(num_gts):
  294. if num_gt > 0:
  295. gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
  296. gt_idx = gt_idx.tile(dn_num_group)
  297. assert len(dn_positive_idx[i]) == len(gt_idx)
  298. dn_match_indices.append((dn_positive_idx[i], gt_idx))
  299. else:
  300. dn_match_indices.append(
  301. (
  302. torch.zeros(0, dtype=torch.int64, device=device),
  303. torch.zeros(0, dtype=torch.int64, device=device),
  304. )
  305. )
  306. return dn_match_indices
  307. def forward(self, outputs, targets):
  308. """
  309. This performs the loss computation.
  310. Args:
  311. outputs (`dict`, *optional*):
  312. Dictionary of tensors, see the output specification of the model for the format.
  313. targets (`list[dict]`, *optional*):
  314. List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
  315. losses applied, see each loss' doc.
  316. """
  317. outputs_without_aux = {k: v for k, v in outputs.items() if "auxiliary_outputs" not in k}
  318. # Retrieve the matching between the outputs of the last layer and the targets
  319. indices = self.matcher(outputs_without_aux, targets)
  320. # Compute the average number of target boxes across all nodes, for normalization purposes
  321. num_boxes = sum(len(t["class_labels"]) for t in targets)
  322. num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
  323. num_boxes = torch.clamp(num_boxes, min=1).item()
  324. # Compute all the requested losses
  325. losses = {}
  326. for loss in self.losses:
  327. l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
  328. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  329. losses.update(l_dict)
  330. # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  331. if "auxiliary_outputs" in outputs:
  332. for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
  333. indices = self.matcher(auxiliary_outputs, targets)
  334. for loss in self.losses:
  335. if loss == "masks":
  336. # Intermediate masks losses are too costly to compute, we ignore them.
  337. continue
  338. l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
  339. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  340. l_dict = {k + f"_aux_{i}": v for k, v in l_dict.items()}
  341. losses.update(l_dict)
  342. # In case of cdn auxiliary losses. For rtdetr
  343. if "dn_auxiliary_outputs" in outputs:
  344. if "denoising_meta_values" not in outputs:
  345. raise ValueError(
  346. "The output must have the 'denoising_meta_values` key. Please, ensure that 'outputs' includes a 'denoising_meta_values' entry."
  347. )
  348. indices = self.get_cdn_matched_indices(outputs["denoising_meta_values"], targets)
  349. num_boxes = num_boxes * outputs["denoising_meta_values"]["dn_num_group"]
  350. for i, auxiliary_outputs in enumerate(outputs["dn_auxiliary_outputs"]):
  351. # indices = self.matcher(auxiliary_outputs, targets)
  352. for loss in self.losses:
  353. if loss == "masks":
  354. # Intermediate masks losses are too costly to compute, we ignore them.
  355. continue
  356. kwargs = {}
  357. l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes, **kwargs)
  358. l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
  359. l_dict = {k + f"_dn_{i}": v for k, v in l_dict.items()}
  360. losses.update(l_dict)
  361. return losses
  362. def RTDetrForObjectDetectionLoss(
  363. logits,
  364. labels,
  365. device,
  366. pred_boxes,
  367. config,
  368. outputs_class=None,
  369. outputs_coord=None,
  370. enc_topk_logits=None,
  371. enc_topk_bboxes=None,
  372. denoising_meta_values=None,
  373. **kwargs,
  374. ):
  375. criterion = RTDetrLoss(config)
  376. criterion.to(device)
  377. # Second: compute the losses, based on outputs and labels
  378. outputs_loss = {}
  379. outputs_loss["logits"] = logits
  380. outputs_loss["pred_boxes"] = pred_boxes
  381. auxiliary_outputs = None
  382. if config.auxiliary_loss:
  383. if denoising_meta_values is not None:
  384. dn_out_coord, outputs_coord = torch.split(outputs_coord, denoising_meta_values["dn_num_split"], dim=2)
  385. dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2)
  386. auxiliary_outputs = _set_aux_loss(outputs_class[:, :-1].transpose(0, 1), outputs_coord[:, :-1].transpose(0, 1))
  387. outputs_loss["auxiliary_outputs"] = auxiliary_outputs
  388. outputs_loss["auxiliary_outputs"].extend(_set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
  389. if denoising_meta_values is not None:
  390. outputs_loss["dn_auxiliary_outputs"] = _set_aux_loss(
  391. dn_out_class.transpose(0, 1), dn_out_coord.transpose(0, 1)
  392. )
  393. outputs_loss["denoising_meta_values"] = denoising_meta_values
  394. loss_dict = criterion(outputs_loss, labels)
  395. loss = sum(loss_dict.values())
  396. return loss, loss_dict, auxiliary_outputs