asymmetric_loss.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import torch
  2. import torch.nn as nn
  3. class AsymmetricLossMultiLabel(nn.Module):
  4. def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
  5. super(AsymmetricLossMultiLabel, self).__init__()
  6. self.gamma_neg = gamma_neg
  7. self.gamma_pos = gamma_pos
  8. self.clip = clip
  9. self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
  10. self.eps = eps
  11. def forward(self, x, y):
  12. """"
  13. Parameters
  14. ----------
  15. x: input logits
  16. y: targets (multi-label binarized vector)
  17. """
  18. # Calculating Probabilities
  19. x_sigmoid = torch.sigmoid(x)
  20. xs_pos = x_sigmoid
  21. xs_neg = 1 - x_sigmoid
  22. # Asymmetric Clipping
  23. if self.clip is not None and self.clip > 0:
  24. xs_neg = (xs_neg + self.clip).clamp(max=1)
  25. # Basic CE calculation
  26. los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
  27. los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
  28. loss = los_pos + los_neg
  29. # Asymmetric Focusing
  30. if self.gamma_neg > 0 or self.gamma_pos > 0:
  31. if self.disable_torch_grad_focal_loss:
  32. torch.set_grad_enabled(False)
  33. pt0 = xs_pos * y
  34. pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
  35. pt = pt0 + pt1
  36. one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
  37. one_sided_w = torch.pow(1 - pt, one_sided_gamma)
  38. if self.disable_torch_grad_focal_loss:
  39. torch.set_grad_enabled(True)
  40. loss *= one_sided_w
  41. return -loss.sum()
  42. class AsymmetricLossSingleLabel(nn.Module):
  43. def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
  44. super(AsymmetricLossSingleLabel, self).__init__()
  45. self.eps = eps
  46. self.logsoftmax = nn.LogSoftmax(dim=-1)
  47. self.targets_classes = [] # prevent gpu repeated memory allocation
  48. self.gamma_pos = gamma_pos
  49. self.gamma_neg = gamma_neg
  50. self.reduction = reduction
  51. def forward(self, inputs, target, reduction=None):
  52. """"
  53. Parameters
  54. ----------
  55. x: input logits
  56. y: targets (1-hot vector)
  57. """
  58. num_classes = inputs.size()[-1]
  59. log_preds = self.logsoftmax(inputs)
  60. self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
  61. # ASL weights
  62. targets = self.targets_classes
  63. anti_targets = 1 - targets
  64. xs_pos = torch.exp(log_preds)
  65. xs_neg = 1 - xs_pos
  66. xs_pos = xs_pos * targets
  67. xs_neg = xs_neg * anti_targets
  68. asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
  69. self.gamma_pos * targets + self.gamma_neg * anti_targets)
  70. log_preds = log_preds * asymmetric_w
  71. if self.eps > 0: # label smoothing
  72. self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)
  73. # loss calculation
  74. loss = - self.targets_classes.mul(log_preds)
  75. loss = loss.sum(dim=-1)
  76. if self.reduction == 'mean':
  77. loss = loss.mean()
  78. return loss