# binary_cross_entropy.py (2.4 KB)
""" Binary Cross Entropy w/ a few extras
Hacked together by / Copyright 2021 Ross Wightman
"""
  4. from typing import Optional, Union
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. class BinaryCrossEntropy(nn.Module):
  9. """ BCE with optional one-hot from dense targets, label smoothing, thresholding
  10. NOTE for experiments comparing CE to BCE /w label smoothing, may remove
  11. """
  12. def __init__(
  13. self,
  14. smoothing=0.1,
  15. target_threshold: Optional[float] = None,
  16. weight: Optional[torch.Tensor] = None,
  17. reduction: str = 'mean',
  18. sum_classes: bool = False,
  19. pos_weight: Optional[Union[torch.Tensor, float]] = None,
  20. ):
  21. super(BinaryCrossEntropy, self).__init__()
  22. assert 0. <= smoothing < 1.0
  23. if pos_weight is not None:
  24. if not isinstance(pos_weight, torch.Tensor):
  25. pos_weight = torch.tensor(pos_weight)
  26. self.smoothing = smoothing
  27. self.target_threshold = target_threshold
  28. self.reduction = 'none' if sum_classes else reduction
  29. self.sum_classes = sum_classes
  30. self.register_buffer('weight', weight)
  31. self.register_buffer('pos_weight', pos_weight)
  32. def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  33. batch_size = x.shape[0]
  34. assert batch_size == target.shape[0]
  35. if target.shape != x.shape:
  36. # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
  37. num_classes = x.shape[-1]
  38. # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
  39. off_value = self.smoothing / num_classes
  40. on_value = 1. - self.smoothing + off_value
  41. target = target.long().view(-1, 1)
  42. target = torch.full(
  43. (batch_size, num_classes),
  44. off_value,
  45. device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
  46. if self.target_threshold is not None:
  47. # Make target 0, or 1 if threshold set
  48. target = target.gt(self.target_threshold).to(dtype=target.dtype)
  49. loss = F.binary_cross_entropy_with_logits(
  50. x, target,
  51. self.weight,
  52. pos_weight=self.pos_weight,
  53. reduction=self.reduction,
  54. )
  55. if self.sum_classes:
  56. loss = loss.sum(-1).mean()
  57. return loss