| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- r"""Losses based on the divergence between probability distributions."""
- from __future__ import annotations
- import torch
- import torch.nn.functional as F
- from kornia.core import Tensor
- def _kl_div_2d(p: Tensor, q: Tensor) -> Tensor:
- # D_KL(P || Q)
- batch, chans, height, width = p.shape
- unsummed_kl = F.kl_div(
- q.reshape(batch * chans, height * width).log(), p.reshape(batch * chans, height * width), reduction="none"
- )
- kl_values = unsummed_kl.sum(-1).view(batch, chans)
- return kl_values
def _js_div_2d(p: Tensor, q: Tensor) -> Tensor:
    r"""Compute the per-channel Jensen-Shannon divergence :math:`JSD(P \| Q)` between 2D maps.

    The JSD is the average of the KL divergences of ``p`` and of ``q`` from their midpoint
    distribution :math:`M = \frac{1}{2}(P + Q)`.

    Args:
        p: the first distribution, shape :math:`(B, N, H, W)`.
        q: the second distribution, expected to match the shape of ``p``.

    Returns:
        the divergence for every batch/channel pair, shape :math:`(B, N)`.
    """
    midpoint = 0.5 * (p + q)
    kl_p = _kl_div_2d(p, midpoint)
    kl_q = _kl_div_2d(q, midpoint)
    return 0.5 * kl_p + 0.5 * kl_q
- # TODO: add this to the main module
- def _reduce_loss(losses: Tensor, reduction: str) -> Tensor:
- if reduction == "none":
- return losses
- return torch.mean(losses) if reduction == "mean" else torch.sum(losses)
def js_div_loss_2d(pred: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculate the Jensen-Shannon divergence loss between heatmaps.

    The divergence is evaluated independently for every batch element and channel. Each
    :math:`(H, W)` map is treated as a discrete distribution over its :math:`H \times W` cells.
    No normalisation is applied, so every map should already be non-negative and sum to one.

    Args:
        pred: the predicted heatmaps with shape :math:`(B, N, H, W)`.
        target: the reference heatmaps with shape :math:`(B, N, H, W)`.
        reduction: how to aggregate the per-channel losses: ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'`` returns the unreduced :math:`(B, N)` tensor, ``'mean'`` averages it and
            ``'sum'`` adds it up.

    Returns:
        the loss: a :math:`(B, N)` tensor for ``'none'``, otherwise a 0-dim tensor.

    Examples:
        >>> pred = torch.full((1, 1, 2, 4), 0.125)
        >>> loss = js_div_loss_2d(pred, pred)
        >>> loss.item()
        0.0
    """
    per_channel_divergence = _js_div_2d(target, pred)
    return _reduce_loss(per_channel_divergence, reduction)
def kl_div_loss_2d(pred: Tensor, target: Tensor, reduction: str = "mean") -> Tensor:
    r"""Calculate the Kullback-Leibler divergence loss between heatmaps.

    The divergence computed is :math:`D_{KL}(\text{target} \| \text{pred})`, evaluated
    independently for every batch element and channel. Each :math:`(H, W)` map is treated as a
    discrete distribution over its :math:`H \times W` cells. No normalisation is applied, so
    every map should already be non-negative and sum to one.

    Args:
        pred: the predicted heatmaps with shape :math:`(B, N, H, W)`.
        target: the reference heatmaps with shape :math:`(B, N, H, W)`.
        reduction: how to aggregate the per-channel losses: ``'none'`` | ``'mean'`` | ``'sum'``.
            ``'none'`` returns the unreduced :math:`(B, N)` tensor, ``'mean'`` averages it and
            ``'sum'`` adds it up.

    Returns:
        the loss: a :math:`(B, N)` tensor for ``'none'``, otherwise a 0-dim tensor.

    Examples:
        >>> pred = torch.full((1, 1, 2, 4), 0.125)
        >>> loss = kl_div_loss_2d(pred, pred)
        >>> loss.item()
        0.0
    """
    # target plays the role of P (reference), pred the role of Q (approximation).
    divergence = _kl_div_2d(target, pred)
    return _reduce_loss(divergence, reduction)
|