focal_loss.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import torch
  2. import torch.nn.functional as F
  3. from ..utils import _log_api_usage_once
  4. def sigmoid_focal_loss(
  5. inputs: torch.Tensor,
  6. targets: torch.Tensor,
  7. alpha: float = 0.25,
  8. gamma: float = 2,
  9. reduction: str = "none",
  10. ) -> torch.Tensor:
  11. """
  12. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  13. Args:
  14. inputs (Tensor): A float tensor of arbitrary shape.
  15. The predictions for each example.
  16. targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
  17. classification label for each element in inputs
  18. (0 for the negative class and 1 for the positive class).
  19. alpha (float): Weighting factor in range [0, 1] to balance
  20. positive vs negative examples or -1 for ignore. Default: ``0.25``.
  21. gamma (float): Exponent of the modulating factor (1 - p_t) to
  22. balance easy vs hard examples. Default: ``2``.
  23. reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
  24. ``'none'``: No reduction will be applied to the output.
  25. ``'mean'``: The output will be averaged.
  26. ``'sum'``: The output will be summed. Default: ``'none'``.
  27. Returns:
  28. Loss tensor with the reduction option applied.
  29. """
  30. # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
  31. if not (0 <= alpha <= 1) and alpha != -1:
  32. raise ValueError(f"Invalid alpha value: {alpha}. alpha must be in the range [0,1] or -1 for ignore.")
  33. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  34. _log_api_usage_once(sigmoid_focal_loss)
  35. p = torch.sigmoid(inputs)
  36. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  37. p_t = p * targets + (1 - p) * (1 - targets)
  38. loss = ce_loss * ((1 - p_t) ** gamma)
  39. if alpha >= 0:
  40. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  41. loss = alpha_t * loss
  42. # Check reduction option and return loss accordingly
  43. if reduction == "none":
  44. pass
  45. elif reduction == "mean":
  46. loss = loss.mean()
  47. elif reduction == "sum":
  48. loss = loss.sum()
  49. else:
  50. raise ValueError(
  51. f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
  52. )
  53. return loss