| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 |
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- import torch.nn as nn
- from torch.nn import BCEWithLogitsLoss, MSELoss
- from .loss_d_fine import DFineForObjectDetectionLoss
- from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
- from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
- from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss
- from .loss_lw_detr import LwDetrForObjectDetectionLoss
- from .loss_rt_detr import RTDetrForObjectDetectionLoss
- def fixed_cross_entropy(
- source: torch.Tensor,
- target: torch.Tensor,
- num_items_in_batch: torch.Tensor | None = None,
- ignore_index: int = -100,
- **kwargs,
- ) -> torch.Tensor:
- reduction = "sum" if num_items_in_batch is not None else "mean"
- loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
- if reduction == "sum":
- # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
- if torch.is_tensor(num_items_in_batch):
- num_items_in_batch = num_items_in_batch.to(loss.device)
- loss = loss / num_items_in_batch
- return loss
def ForCausalLMLoss(
    logits,
    labels,
    vocab_size: int,
    num_items_in_batch: torch.Tensor | None = None,
    ignore_index: int = -100,
    shift_labels: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Next-token cross-entropy loss for causal language models.

    Args:
        logits: Prediction scores, flattened to ``(-1, vocab_size)`` (presumably
            ``vocab_size`` is the last dimension — confirm at call sites).
        labels: Target token ids aligned with ``logits``. Only used when
            ``shift_labels`` is ``None``: they are right-padded with ``ignore_index``
            and shifted left by one, so position ``i`` is scored against token ``i + 1``.
        vocab_size: Width used when flattening ``logits``.
        num_items_in_batch: Optional normaliser forwarded to ``fixed_cross_entropy``.
        ignore_index: Label value excluded from the loss; also the pad value for the shift.
        shift_labels: Already-shifted labels; when given, ``labels`` is not used.
        **kwargs: Forwarded to ``fixed_cross_entropy``.

    Returns:
        The reduced loss returned by ``fixed_cross_entropy``.
    """
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()

    if shift_labels is None:
        # Shift so that tokens < n predict n; the pad keeps the sequence length unchanged
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens. `reshape` (unlike `view`) also accepts non-contiguous inputs,
    # e.g. sliced logits or caller-provided `shift_labels`; it returns a view when possible.
    logits = logits.reshape(-1, vocab_size)
    shift_labels = shift_labels.reshape(-1).to(logits.device)
    return fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
def ForMaskedLMLoss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: torch.Tensor | None = None,
    ignore_index: int = -100,
    **kwargs,
):
    """Masked-language-modelling cross-entropy (labels are used without shifting).

    Args:
        logits: Prediction scores, flattened to ``(-1, vocab_size)``.
        labels: Target token ids aligned position-for-position with ``logits``;
            entries equal to ``ignore_index`` are excluded from the loss.
        vocab_size: Width used when flattening ``logits``.
        num_items_in_batch: Optional normaliser forwarded to ``fixed_cross_entropy``.
        ignore_index: Label value excluded from the loss.
        **kwargs: Forwarded to ``fixed_cross_entropy``.

    Returns:
        The reduced loss returned by ``fixed_cross_entropy``.
    """
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Flatten the tokens. `reshape` also accepts non-contiguous inputs (e.g. sliced labels),
    # where `view` would raise; it still returns a view whenever one is possible.
    logits = logits.reshape(-1, vocab_size)
    labels = labels.reshape(-1).to(logits.device)
    return fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
    """Loss for sequence-level classification/regression heads.

    The loss is chosen from ``config.problem_type``. When that is ``None`` it is
    inferred and written back onto ``config``:

    - ``num_labels == 1`` -> ``"regression"``
    - ``num_labels > 1`` and labels of dtype ``torch.long``/``torch.int`` -> ``"single_label_classification"``
    - otherwise -> ``"multi_label_classification"``

    Args:
        labels: Targets; moved to ``pooled_logits.device``.
        pooled_logits: Model outputs.
        config: Object exposing ``num_labels`` and a writable ``problem_type``.
        **kwargs: Forwarded to ``fixed_cross_entropy`` (single-label case only).

    Returns:
        MSE loss (regression), cross-entropy (single-label) or BCE-with-logits (multi-label).

    Raises:
        RuntimeError: If ``config.problem_type`` holds an unsupported value.
    """
    num_labels = config.num_labels
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    labels = labels.to(pooled_logits.device)
    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            return loss_fct(pooled_logits.squeeze(), labels.squeeze())
        return loss_fct(pooled_logits, labels)
    if config.problem_type == "single_label_classification":
        # `reshape` tolerates non-contiguous inputs where `view` would raise
        return fixed_cross_entropy(pooled_logits.reshape(-1, num_labels), labels.reshape(-1), **kwargs)
    if config.problem_type == "multi_label_classification":
        # BCEWithLogitsLoss needs floating-point targets; multi-hot labels commonly arrive
        # as bool/uint8/long, so cast them to the logits' dtype (a no-op for matching floats).
        loss_fct = BCEWithLogitsLoss()
        return loss_fct(pooled_logits, labels.to(pooled_logits.dtype))
    raise RuntimeError(f"Invalid problem type: {config.problem_type}")
def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions, **kwargs):
    """Span-extraction loss: the mean of the start- and end-position cross-entropies.

    Args:
        start_logits: Scores for the span start; ``start_logits.size(1)`` is used as
            the out-of-range sentinel index.
        end_logits: Scores for the span end.
        start_positions: Gold start indices, or ``None``.
        end_positions: Gold end indices, or ``None``.
        **kwargs: Forwarded to ``fixed_cross_entropy``.

    Returns:
        ``(start_loss + end_loss) / 2``, or ``None`` if either position tensor is ``None``.
    """
    if start_positions is None or end_positions is None:
        return None

    # Drop a trailing extra dimension if present (the original note attributes this to
    # multi-GPU runs). Always move to the logits' device, including when the positions
    # are already 1-D; previously only the squeezed branch was moved, so 1-D positions
    # on another device caused a device mismatch.
    if start_positions.dim() > 1:
        start_positions = start_positions.squeeze(-1)
    if end_positions.dim() > 1:
        end_positions = end_positions.squeeze(-1)
    start_positions = start_positions.to(start_logits.device)
    end_positions = end_positions.to(end_logits.device)

    # Clamp into [0, ignored_index]: positions at or beyond the input length collapse to
    # `ignored_index` and are excluded via `ignore_index`; negative positions become 0.
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    start_loss = fixed_cross_entropy(start_logits, start_positions, ignore_index=ignored_index, **kwargs)
    end_loss = fixed_cross_entropy(end_logits, end_positions, ignore_index=ignored_index, **kwargs)
    return (start_loss + end_loss) / 2
def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs):
    """Per-token classification cross-entropy.

    Args:
        logits: Scores flattened to ``(-1, config.num_labels)``.
        labels: Per-token class ids, flattened to 1-D and moved to the logits' device.
        config: Object exposing ``num_labels``.
        **kwargs: Forwarded to ``fixed_cross_entropy`` (e.g. ``num_items_in_batch``,
            ``ignore_index``).

    Returns:
        The reduced loss returned by ``fixed_cross_entropy``.
    """
    # Flatten the tokens; `reshape` also accepts non-contiguous inputs where `view` would raise
    logits = logits.reshape(-1, config.num_labels)
    labels = labels.reshape(-1).to(logits.device)
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    return fixed_cross_entropy(logits, labels, **kwargs)
# Registry mapping a model/head name to the loss callable used for it.
# Built from an ordered list of pairs so the key insertion order is explicit.
LOSS_MAPPING = dict(
    [
        # Generic task heads
        ("ForCausalLM", ForCausalLMLoss),
        ("ForMaskedLM", ForMaskedLMLoss),
        ("ForQuestionAnswering", ForQuestionAnsweringLoss),
        ("ForSequenceClassification", ForSequenceClassificationLoss),
        ("ForImageClassification", ForSequenceClassificationLoss),
        ("ForVideoClassification", ForSequenceClassificationLoss),
        ("ForAudioClassification", ForSequenceClassificationLoss),
        ("ForTokenClassification", ForTokenClassification),
        ("ForSegmentation", ForSegmentationLoss),
        ("ForObjectDetection", ForObjectDetectionLoss),
        ("ForConditionalGeneration", ForCausalLMLoss),
        # Model-specific overrides
        ("DeformableDetrForObjectDetection", DeformableDetrForObjectDetectionLoss),
        ("ConditionalDetrForObjectDetection", DeformableDetrForObjectDetectionLoss),
        ("DabDetrForObjectDetection", DeformableDetrForObjectDetectionLoss),
        ("GroundingDinoForObjectDetection", GroundingDinoForObjectDetectionLoss),
        ("MMGroundingDinoForObjectDetection", GroundingDinoForObjectDetectionLoss),
        ("ConditionalDetrForSegmentation", DeformableDetrForSegmentationLoss),
        ("RTDetrForObjectDetection", RTDetrForObjectDetectionLoss),
        ("RTDetrV2ForObjectDetection", RTDetrForObjectDetectionLoss),
        ("DFineForObjectDetection", DFineForObjectDetectionLoss),
        ("CsmForConditionalGeneration", ForCausalLMLoss),
        ("LwDetrForObjectDetection", LwDetrForObjectDetectionLoss),
    ]
)
|