loss_utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. from torch.nn import BCEWithLogitsLoss, MSELoss
  17. from .loss_d_fine import DFineForObjectDetectionLoss
  18. from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
  19. from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
  20. from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
  21. from .loss_lw_detr import LwDetrForObjectDetectionLoss
  22. from .loss_rt_detr import RTDetrForObjectDetectionLoss
  23. def fixed_cross_entropy(
  24. source: torch.Tensor,
  25. target: torch.Tensor,
  26. num_items_in_batch: torch.Tensor | None = None,
  27. ignore_index: int = -100,
  28. **kwargs,
  29. ) -> torch.Tensor:
  30. reduction = "sum" if num_items_in_batch is not None else "mean"
  31. loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
  32. if reduction == "sum":
  33. # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
  34. if torch.is_tensor(num_items_in_batch):
  35. num_items_in_batch = num_items_in_batch.to(loss.device)
  36. loss = loss / num_items_in_batch
  37. return loss
  38. def ForCausalLMLoss(
  39. logits,
  40. labels,
  41. vocab_size: int,
  42. num_items_in_batch: torch.Tensor | None = None,
  43. ignore_index: int = -100,
  44. shift_labels: torch.Tensor | None = None,
  45. **kwargs,
  46. ) -> torch.Tensor:
  47. # Upcast to float if we need to compute the loss to avoid potential precision issues
  48. logits = logits.float()
  49. if shift_labels is None:
  50. # Shift so that tokens < n predict n
  51. labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
  52. shift_labels = labels[..., 1:].contiguous()
  53. # Flatten the tokens
  54. logits = logits.view(-1, vocab_size)
  55. shift_labels = shift_labels.view(-1)
  56. shift_labels = shift_labels.to(logits.device)
  57. loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
  58. return loss
  59. def ForMaskedLMLoss(
  60. logits: torch.Tensor,
  61. labels: torch.Tensor,
  62. vocab_size: int,
  63. num_items_in_batch: torch.Tensor | None = None,
  64. ignore_index: int = -100,
  65. **kwargs,
  66. ):
  67. # Upcast to float if we need to compute the loss to avoid potential precision issues
  68. logits = logits.float()
  69. # Flatten the tokens
  70. logits = logits.view(-1, vocab_size)
  71. labels = labels.view(-1)
  72. labels = labels.to(logits.device)
  73. loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
  74. return loss
  75. def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
  76. num_labels = config.num_labels
  77. if config.problem_type is None:
  78. if num_labels == 1:
  79. config.problem_type = "regression"
  80. elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
  81. config.problem_type = "single_label_classification"
  82. else:
  83. config.problem_type = "multi_label_classification"
  84. labels = labels.to(pooled_logits.device)
  85. if config.problem_type == "regression":
  86. loss_fct = MSELoss()
  87. if num_labels == 1:
  88. return loss_fct(pooled_logits.squeeze(), labels.squeeze())
  89. else:
  90. return loss_fct(pooled_logits, labels)
  91. if config.problem_type == "single_label_classification":
  92. return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)
  93. if config.problem_type == "multi_label_classification":
  94. loss_fct = BCEWithLogitsLoss()
  95. return loss_fct(pooled_logits, labels)
  96. raise RuntimeError(f"Invalid problem type: {config.problem_type}")
  97. def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
  98. total_loss = None
  99. if start_positions is not None and end_positions is not None:
  100. # If we are on multi-GPU, split add a dimension
  101. if len(start_positions.size()) > 1:
  102. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  103. if len(end_positions.size()) > 1:
  104. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  105. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  106. ignored_index = start_logits.size(1)
  107. start_positions = start_positions.clamp(0, ignored_index)
  108. end_positions = end_positions.clamp(0, ignored_index)
  109. start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
  110. end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
  111. total_loss = (start_loss + end_loss) / 2
  112. return total_loss
  113. def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
  114. # Upcast to float if we need to compute the loss to avoid potential precision issues
  115. logits = logits.view(-1, config.num_labels)
  116. labels = labels.view(-1).to(logits.device)
  117. logits = logits.float()
  118. # Flatten the tokens
  119. return fixed_cross_entropy(logits, labels, **kwargs)
  120. LOSS_MAPPING = {
  121. "ForCausalLM": ForCausalLMLoss,
  122. "ForMaskedLM": ForMaskedLMLoss,
  123. "ForQuestionAnswering": ForQuestionAnsweringLoss,
  124. "ForSequenceClassification": ForSequenceClassificationLoss,
  125. "ForImageClassification": ForSequenceClassificationLoss,
  126. "ForVideoClassification": ForSequenceClassificationLoss,
  127. "ForAudioClassification": ForSequenceClassificationLoss,
  128. "ForTokenClassification": ForTokenClassification,
  129. "ForSegmentation": ForSegmentationLoss,
  130. "ForObjectDetection": ForObjectDetectionLoss,
  131. "ForConditionalGeneration": ForCausalLMLoss,
  132. "DeformableDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
  133. "ConditionalDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
  134. "DabDetrForObjectDetection": DeformableDetrForObjectDetectionLoss,
  135. "GroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
  136. "MMGroundingDinoForObjectDetection": GroundingDinoForObjectDetectionLoss,
  137. "ConditionalDetrForSegmentation": DeformableDetrForSegmentationLoss,
  138. "RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
  139. "RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
  140. "DFineForObjectDetection": DFineForObjectDetectionLoss,
  141. "CsmForConditionalGeneration": ForCausalLMLoss,
  142. "LwDetrForObjectDetection": LwDetrForObjectDetectionLoss,
  143. }