jsd.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .cross_entropy import LabelSmoothingCrossEntropy
  5. class JsdCrossEntropy(nn.Module):
  6. """ Jensen-Shannon Divergence + Cross-Entropy Loss
  7. Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
  8. From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
  9. https://arxiv.org/abs/1912.02781
  10. Hacked together by / Copyright 2020 Ross Wightman
  11. """
  12. def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
  13. super().__init__()
  14. self.num_splits = num_splits
  15. self.alpha = alpha
  16. if smoothing is not None and smoothing > 0:
  17. self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
  18. else:
  19. self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
  20. def __call__(self, output, target):
  21. split_size = output.shape[0] // self.num_splits
  22. assert split_size * self.num_splits == output.shape[0]
  23. logits_split = torch.split(output, split_size)
  24. # Cross-entropy is only computed on clean images
  25. loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
  26. probs = [F.softmax(logits, dim=1) for logits in logits_split]
  27. # Clamp mixture distribution to avoid exploding KL divergence
  28. logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
  29. loss += self.alpha * sum([F.kl_div(
  30. logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
  31. return loss