focal.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. from torch import nn
  21. from kornia.core import Tensor, tensor
  22. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  23. from kornia.losses._utils import mask_ignore_pixels
  24. from kornia.utils.one_hot import one_hot
  25. # based on:
  26. # https://github.com/zhezh/focalloss/blob/master/focalloss.py
  27. def focal_loss(
  28. pred: Tensor,
  29. target: Tensor,
  30. alpha: Optional[float],
  31. gamma: float = 2.0,
  32. reduction: str = "none",
  33. weight: Optional[Tensor] = None,
  34. ignore_index: Optional[int] = -100,
  35. ) -> Tensor:
  36. r"""Criterion that computes Focal loss.
  37. According to :cite:`lin2018focal`, the Focal loss is computed as follows:
  38. .. math::
  39. \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
  40. Where:
  41. - :math:`p_t` is the model's estimated probability for each class.
  42. Args:
  43. pred: logits tensor with shape :math:`(N, C, *)` where C = number of classes.
  44. target: labels tensor with shape :math:`(N, *)` where each value is an integer
  45. representing correct classification :math:`target[i] \in [0, C)`.
  46. alpha: Weighting factor :math:`\alpha \in [0, 1]`.
  47. gamma: Focusing parameter :math:`\gamma >= 0`.
  48. reduction: Specifies the reduction to apply to the
  49. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  50. will be applied, ``'mean'``: the sum of the output will be divided by
  51. the number of elements in the output, ``'sum'``: the output will be
  52. summed.
  53. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  54. ignore_index: labels with this value are ignored in the loss computation.
  55. Return:
  56. the computed loss.
  57. Example:
  58. >>> C = 5 # num_classes
  59. >>> pred = torch.randn(1, C, 3, 5, requires_grad=True)
  60. >>> target = torch.randint(C, (1, 3, 5))
  61. >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
  62. >>> output = focal_loss(pred, target, **kwargs)
  63. >>> output.backward()
  64. """
  65. KORNIA_CHECK_SHAPE(pred, ["B", "C", "*"])
  66. out_size = (pred.shape[0],) + pred.shape[2:]
  67. KORNIA_CHECK(
  68. (pred.shape[0] == target.shape[0] and target.shape[1:] == pred.shape[2:]),
  69. f"Expected target size {out_size}, got {target.shape}",
  70. )
  71. KORNIA_CHECK(
  72. pred.device == target.device,
  73. f"pred and target must be in the same device. Got: {pred.device} and {target.device}",
  74. )
  75. target, target_mask = mask_ignore_pixels(target, ignore_index)
  76. # create the labels one hot tensor
  77. target_one_hot: Tensor = one_hot(target, num_classes=pred.shape[1], device=pred.device, dtype=pred.dtype)
  78. # mask ignore pixels
  79. if target_mask is not None:
  80. target_mask.unsqueeze_(1)
  81. target_one_hot = target_one_hot * target_mask
  82. # compute softmax over the classes axis
  83. log_pred_soft: Tensor = pred.log_softmax(1)
  84. # compute the actual focal loss
  85. loss_tmp: Tensor = -torch.pow(1.0 - log_pred_soft.exp(), gamma) * log_pred_soft * target_one_hot
  86. num_of_classes = pred.shape[1]
  87. broadcast_dims = [-1] + [1] * len(pred.shape[2:])
  88. if alpha is not None:
  89. alpha_fac = tensor([1 - alpha] + [alpha] * (num_of_classes - 1), dtype=loss_tmp.dtype, device=loss_tmp.device)
  90. alpha_fac = alpha_fac.view(broadcast_dims)
  91. loss_tmp = alpha_fac * loss_tmp
  92. if weight is not None:
  93. KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
  94. KORNIA_CHECK(
  95. (weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
  96. f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
  97. )
  98. KORNIA_CHECK(
  99. weight.device == pred.device,
  100. f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
  101. )
  102. weight = weight.view(broadcast_dims)
  103. loss_tmp = weight * loss_tmp
  104. if reduction == "none":
  105. loss = loss_tmp
  106. elif reduction == "mean":
  107. loss = torch.mean(loss_tmp)
  108. elif reduction == "sum":
  109. loss = torch.sum(loss_tmp)
  110. else:
  111. raise NotImplementedError(f"Invalid reduction mode: {reduction}")
  112. return loss
  113. class FocalLoss(nn.Module):
  114. r"""Criterion that computes Focal loss.
  115. According to :cite:`lin2018focal`, the Focal loss is computed as follows:
  116. .. math::
  117. \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
  118. Where:
  119. - :math:`p_t` is the model's estimated probability for each class.
  120. Args:
  121. alpha: Weighting factor :math:`\alpha \in [0, 1]`.
  122. gamma: Focusing parameter :math:`\gamma >= 0`.
  123. reduction: Specifies the reduction to apply to the
  124. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  125. will be applied, ``'mean'``: the sum of the output will be divided by
  126. the number of elements in the output, ``'sum'``: the output will be
  127. summed.
  128. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  129. ignore_index: labels with this value are ignored in the loss computation.
  130. Shape:
  131. - Pred: :math:`(N, C, *)` where C = number of classes.
  132. - Target: :math:`(N, *)` where each value is an integer
  133. representing correct classification :math:`target[i] \in [0, C)`.
  134. Example:
  135. >>> C = 5 # num_classes
  136. >>> pred = torch.randn(1, C, 3, 5, requires_grad=True)
  137. >>> target = torch.randint(C, (1, 3, 5))
  138. >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
  139. >>> criterion = FocalLoss(**kwargs)
  140. >>> output = criterion(pred, target)
  141. >>> output.backward()
  142. """
  143. def __init__(
  144. self,
  145. alpha: Optional[float],
  146. gamma: float = 2.0,
  147. reduction: str = "none",
  148. weight: Optional[Tensor] = None,
  149. ignore_index: Optional[int] = -100,
  150. ) -> None:
  151. super().__init__()
  152. self.alpha: Optional[float] = alpha
  153. self.gamma: float = gamma
  154. self.reduction: str = reduction
  155. self.weight: Optional[Tensor] = weight
  156. self.ignore_index: Optional[int] = ignore_index
  157. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  158. return focal_loss(pred, target, self.alpha, self.gamma, self.reduction, self.weight, self.ignore_index)
  159. def binary_focal_loss_with_logits(
  160. pred: Tensor,
  161. target: Tensor,
  162. alpha: Optional[float] = 0.25,
  163. gamma: float = 2.0,
  164. reduction: str = "none",
  165. pos_weight: Optional[Tensor] = None,
  166. weight: Optional[Tensor] = None,
  167. ignore_index: Optional[int] = -100,
  168. ) -> Tensor:
  169. r"""Criterion that computes Binary Focal loss.
  170. According to :cite:`lin2018focal`, the Focal loss is computed as follows:
  171. .. math::
  172. \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
  173. where:
  174. - :math:`p_t` is the model's estimated probability for each class.
  175. Args:
  176. pred: logits tensor with shape :math:`(N, C, *)` where C = number of classes.
  177. target: labels tensor with the same shape as pred :math:`(N, C, *)`
  178. where each value is between 0 and 1.
  179. alpha: Weighting factor :math:`\alpha \in [0, 1]`.
  180. gamma: Focusing parameter :math:`\gamma >= 0`.
  181. reduction: Specifies the reduction to apply to the
  182. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  183. will be applied, ``'mean'``: the sum of the output will be divided by
  184. the number of elements in the output, ``'sum'``: the output will be
  185. summed.
  186. pos_weight: a weight of positive examples with shape :math:`(num\_of\_classes,)`.
  187. It is possible to trade off recall and precision by adding weights to positive examples.
  188. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  189. ignore_index: labels with this value are ignored in the loss computation.
  190. Returns:
  191. the computed loss.
  192. Examples:
  193. >>> C = 3 # num_classes
  194. >>> pred = torch.randn(1, C, 5, requires_grad=True)
  195. >>> target = torch.randint(2, (1, C, 5))
  196. >>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
  197. >>> output = binary_focal_loss_with_logits(pred, target, **kwargs)
  198. >>> output.backward()
  199. """
  200. KORNIA_CHECK_SHAPE(pred, ["B", "C", "*"])
  201. KORNIA_CHECK(pred.shape == target.shape, f"Expected target size {pred.shape}, got {target.shape}")
  202. KORNIA_CHECK(
  203. pred.device == target.device,
  204. f"pred and target must be in the same device. Got: {pred.device} and {target.device}",
  205. )
  206. log_probs_pos: Tensor = nn.functional.logsigmoid(pred)
  207. log_probs_neg: Tensor = nn.functional.logsigmoid(-pred)
  208. target, target_mask = mask_ignore_pixels(target, ignore_index)
  209. if target_mask is not None:
  210. # mask ignore pixels
  211. log_probs_neg = log_probs_neg * target_mask
  212. log_probs_pos = log_probs_pos * target_mask
  213. pos_term: Tensor = -log_probs_neg.exp().pow(gamma) * target * log_probs_pos
  214. neg_term: Tensor = -log_probs_pos.exp().pow(gamma) * (1.0 - target) * log_probs_neg
  215. if alpha is not None:
  216. pos_term = alpha * pos_term
  217. neg_term = (1.0 - alpha) * neg_term
  218. num_of_classes = pred.shape[1]
  219. broadcast_dims = [-1] + [1] * len(pred.shape[2:])
  220. if pos_weight is not None:
  221. KORNIA_CHECK_IS_TENSOR(pos_weight, "pos_weight must be Tensor or None.")
  222. KORNIA_CHECK(
  223. (pos_weight.shape[0] == num_of_classes and pos_weight.numel() == num_of_classes),
  224. f"pos_weight shape must be (num_of_classes,): ({num_of_classes},), got {pos_weight.shape}",
  225. )
  226. KORNIA_CHECK(
  227. pos_weight.device == pred.device,
  228. f"pos_weight and pred must be in the same device. Got: {pos_weight.device} and {pred.device}",
  229. )
  230. pos_weight = pos_weight.view(broadcast_dims)
  231. pos_term = pos_weight * pos_term
  232. loss_tmp: Tensor = pos_term + neg_term
  233. if weight is not None:
  234. KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
  235. KORNIA_CHECK(
  236. (weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
  237. f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
  238. )
  239. KORNIA_CHECK(
  240. weight.device == pred.device,
  241. f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
  242. )
  243. weight = weight.view(broadcast_dims)
  244. loss_tmp = weight * loss_tmp
  245. if reduction == "none":
  246. loss = loss_tmp
  247. elif reduction == "mean":
  248. loss = torch.mean(loss_tmp)
  249. elif reduction == "sum":
  250. loss = torch.sum(loss_tmp)
  251. else:
  252. raise NotImplementedError(f"Invalid reduction mode: {reduction}")
  253. return loss
  254. class BinaryFocalLossWithLogits(nn.Module):
  255. r"""Criterion that computes Focal loss.
  256. According to :cite:`lin2018focal`, the Focal loss is computed as follows:
  257. .. math::
  258. \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
  259. where:
  260. - :math:`p_t` is the model's estimated probability for each class.
  261. Args:
  262. alpha: Weighting factor :math:`\alpha \in [0, 1]`.
  263. gamma: Focusing parameter :math:`\gamma >= 0`.
  264. reduction: Specifies the reduction to apply to the
  265. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  266. will be applied, ``'mean'``: the sum of the output will be divided by
  267. the number of elements in the output, ``'sum'``: the output will be
  268. summed.
  269. pos_weight: a weight of positive examples with shape :math:`(num\_of\_classes,)`.
  270. It is possible to trade off recall and precision by adding weights to positive examples.
  271. weight: weights for classes with shape :math:`(num\_of\_classes,)`.
  272. ignore_index: labels with this value are ignored in the loss computation.
  273. Shape:
  274. - Pred: :math:`(N, C, *)` where C = number of classes.
  275. - Target: the same shape as Pred :math:`(N, C, *)`
  276. where each value is between 0 and 1.
  277. Examples:
  278. >>> C = 3 # num_classes
  279. >>> pred = torch.randn(1, C, 5, requires_grad=True)
  280. >>> target = torch.randint(2, (1, C, 5))
  281. >>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
  282. >>> criterion = BinaryFocalLossWithLogits(**kwargs)
  283. >>> output = criterion(pred, target)
  284. >>> output.backward()
  285. """
  286. def __init__(
  287. self,
  288. alpha: Optional[float],
  289. gamma: float = 2.0,
  290. reduction: str = "none",
  291. pos_weight: Optional[Tensor] = None,
  292. weight: Optional[Tensor] = None,
  293. ignore_index: Optional[int] = -100,
  294. ) -> None:
  295. super().__init__()
  296. self.alpha: Optional[float] = alpha
  297. self.gamma: float = gamma
  298. self.reduction: str = reduction
  299. self.pos_weight: Optional[Tensor] = pos_weight
  300. self.weight: Optional[Tensor] = weight
  301. self.ignore_index: Optional[int] = ignore_index
  302. def forward(self, pred: Tensor, target: Tensor) -> Tensor:
  303. return binary_focal_loss_with_logits(
  304. pred, target, self.alpha, self.gamma, self.reduction, self.pos_weight, self.weight, self.ignore_index
  305. )