divergence.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. r"""Losses based on the divergence between probability distributions."""
  18. from __future__ import annotations
  19. import torch
  20. import torch.nn.functional as F
  21. from kornia.core import Tensor
  22. def _kl_div_2d(p: Tensor, q: Tensor) -> Tensor:
  23. # D_KL(P || Q)
  24. batch, chans, height, width = p.shape
  25. unsummed_kl = F.kl_div(
  26. q.reshape(batch * chans, height * width).log(), p.reshape(batch * chans, height * width), reduction="none"
  27. )
  28. kl_values = unsummed_kl.sum(-1).view(batch, chans)
  29. return kl_values
  30. def _js_div_2d(p: Tensor, q: Tensor) -> Tensor:
  31. # JSD(P || Q)
  32. m = 0.5 * (p + q)
  33. return 0.5 * _kl_div_2d(p, m) + 0.5 * _kl_div_2d(q, m)
  34. # TODO: add this to the main module
  35. def _reduce_loss(losses: Tensor, reduction: str) -> Tensor:
  36. if reduction == "none":
  37. return losses
  38. return torch.mean(losses) if reduction == "mean" else torch.sum(losses)
  39. def js_div_loss_2d(pred: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
  40. r"""Calculate the Jensen-Shannon divergence loss between heatmaps.
  41. Args:
  42. pred: the input tensor with shape :math:`(B, N, H, W)`.
  43. target: the target tensor with shape :math:`(B, N, H, W)`.
  44. reduction: Specifies the reduction to apply to the
  45. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  46. will be applied, ``'mean'``: the sum of the output will be divided by
  47. the number of elements in the output, ``'sum'``: the output will be
  48. summed.
  49. Examples:
  50. >>> pred = torch.full((1, 1, 2, 4), 0.125)
  51. >>> loss = js_div_loss_2d(pred, pred)
  52. >>> loss.item()
  53. 0.0
  54. """
  55. return _reduce_loss(_js_div_2d(target, pred), reduction)
  56. def kl_div_loss_2d(pred: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
  57. r"""Calculate the Kullback-Leibler divergence loss between heatmaps.
  58. Args:
  59. pred: the input tensor with shape :math:`(B, N, H, W)`.
  60. target: the target tensor with shape :math:`(B, N, H, W)`.
  61. reduction: Specifies the reduction to apply to the
  62. output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  63. will be applied, ``'mean'``: the sum of the output will be divided by
  64. the number of elements in the output, ``'sum'``: the output will be
  65. summed.
  66. Examples:
  67. >>> pred = torch.full((1, 1, 2, 4), 0.125)
  68. >>> loss = kl_div_loss_2d(pred, pred)
  69. >>> loss.item()
  70. 0.0
  71. """
  72. return _reduce_loss(_kl_div_2d(target, pred), reduction)