cross_entropy.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536
  1. """ Cross Entropy w/ smoothing or soft targets
  2. Hacked together by / Copyright 2021 Ross Wightman
  3. """
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. class LabelSmoothingCrossEntropy(nn.Module):
  8. """ NLL loss with label smoothing.
  9. """
  10. def __init__(self, smoothing=0.1):
  11. super(LabelSmoothingCrossEntropy, self).__init__()
  12. assert smoothing < 1.0
  13. self.smoothing = smoothing
  14. self.confidence = 1. - smoothing
  15. def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  16. logprobs = F.log_softmax(x, dim=-1)
  17. nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
  18. nll_loss = nll_loss.squeeze(1)
  19. smooth_loss = -logprobs.mean(dim=-1)
  20. loss = self.confidence * nll_loss + self.smoothing * smooth_loss
  21. return loss.mean()
  22. class SoftTargetCrossEntropy(nn.Module):
  23. def __init__(self):
  24. super(SoftTargetCrossEntropy, self).__init__()
  25. def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  26. loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
  27. return loss.mean()