| 123456789101112131415161718192021222324252627282930313233343536 |
- """ Cross-entropy losses with label smoothing or soft targets
- Hacked together by / Copyright 2021 Ross Wightman
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    The per-position loss is a convex combination of the negative
    log-likelihood of the target class and the mean negative log-probability
    over all classes::

        loss = (1 - smoothing) * nll + smoothing * mean_c(-log_softmax(x)_c)

    The returned value is the mean of that loss over every non-class position.

    Args:
        smoothing: weight given to the mean-over-classes term. Must be
            strictly less than 1.0 (enforced with an ``assert``).
    """

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        # Weight on the true-class NLL term.
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed loss.

        Args:
            x: logits with the class dimension last, e.g. ``(N, C)`` or
                ``(N, d1, ..., C)``.
            target: int64 class indices whose shape is ``x.shape[:-1]``.

        Returns:
            Scalar tensor: the smoothed loss averaged over all positions.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # Gather along the last (class) dim so the index stays aligned with the
        # log_softmax/mean above. The previous unsqueeze(1)/squeeze(1) was only
        # correct for 2-D input; for (N, T, C) it silently gathered
        # logprobs[n, 0, target[n, t]] instead of logprobs[n, t, target[n, t]].
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
class SoftTargetCrossEntropy(nn.Module):
    """Cross entropy against a soft target distribution.

    For every position the loss is ``-sum_c target_c * log_softmax(x)_c``,
    reduced over the last dimension; the result is then averaged over all
    remaining positions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean soft-target cross entropy.

        Args:
            x: logits with the class dimension last.
            target: per-class weights broadcastable against ``x``; presumably
                a probability distribution per position (not validated here).

        Returns:
            Scalar tensor: the loss averaged over all positions.
        """
        log_probs = F.log_softmax(x, dim=-1)
        per_position = (-target * log_probs).sum(dim=-1)
        return per_position.mean()
|