dsnt.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. r"""Implementation of "differentiable spatial to numerical" (soft-argmax) operations.
  18. As described in the paper "Numerical Coordinate Regression with Convolutional Neural Networks" by Nibali et al.
  19. """
  20. from __future__ import annotations
  21. from typing import Optional
  22. import torch
  23. from kornia.core import Tensor, concatenate, softmax
  24. from kornia.core.check import KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
  25. from kornia.utils.grid import create_meshgrid
  26. def _validate_batched_image_tensor_input(tensor: Tensor) -> None:
  27. KORNIA_CHECK_IS_TENSOR(tensor)
  28. KORNIA_CHECK_SHAPE(tensor, ["B", "C", "H", "W"])
  29. def spatial_softmax2d(input: Tensor, temperature: Optional[Tensor] = None) -> Tensor:
  30. r"""Apply the Softmax function over features in each image channel.
  31. Note that this function behaves differently to :py:class:`torch.nn.Softmax2d`, which
  32. instead applies Softmax over features at each spatial location.
  33. Args:
  34. input: the input tensor with shape :math:`(B, N, H, W)`.
  35. temperature: factor to apply to input, adjusting the "smoothness" of the output distribution.
  36. Returns:
  37. a 2D probability distribution per image channel with shape :math:`(B, N, H, W)`.
  38. Examples:
  39. >>> heatmaps = torch.tensor([[[
  40. ... [0., 0., 0.],
  41. ... [0., 0., 0.],
  42. ... [0., 1., 2.]]]])
  43. >>> spatial_softmax2d(heatmaps)
  44. tensor([[[[0.0585, 0.0585, 0.0585],
  45. [0.0585, 0.0585, 0.0585],
  46. [0.0585, 0.1589, 0.4319]]]])
  47. """
  48. _validate_batched_image_tensor_input(input)
  49. batch_size, channels, height, width = input.shape
  50. if temperature is None:
  51. temperature = torch.tensor(1.0)
  52. temperature = temperature.to(device=input.device, dtype=input.dtype)
  53. x = input.view(batch_size, channels, -1)
  54. x_soft = softmax(x * temperature, dim=-1)
  55. return x_soft.view(batch_size, channels, height, width)
  56. def spatial_expectation2d(input: Tensor, normalized_coordinates: bool = True) -> Tensor:
  57. r"""Compute the expectation of coordinate values using spatial probabilities.
  58. The input heatmap is assumed to represent a valid spatial probability distribution,
  59. which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.
  60. Args:
  61. input: the input tensor representing dense spatial probabilities with shape :math:`(B, N, H, W)`.
  62. normalized_coordinates: whether to return the coordinates normalized in the range
  63. of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.
  64. Returns:
  65. expected value of the 2D coordinates with shape :math:`(B, N, 2)`. Output order of the coordinates is (x, y).
  66. Examples:
  67. >>> heatmaps = torch.tensor([[[
  68. ... [0., 0., 0.],
  69. ... [0., 0., 0.],
  70. ... [0., 1., 0.]]]])
  71. >>> spatial_expectation2d(heatmaps, False)
  72. tensor([[[1., 2.]]])
  73. """
  74. _validate_batched_image_tensor_input(input)
  75. batch_size, channels, height, width = input.shape
  76. # Create coordinates grid.
  77. grid = create_meshgrid(height, width, normalized_coordinates, input.device)
  78. grid = grid.to(input.dtype)
  79. pos_x = grid[..., 0].reshape(-1)
  80. pos_y = grid[..., 1].reshape(-1)
  81. input_flat = input.view(batch_size, channels, -1)
  82. # Compute the expectation of the coordinates.
  83. expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
  84. expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)
  85. output = concatenate([expected_x, expected_y], -1)
  86. return output.view(batch_size, channels, 2) # BxNx2
  87. def _safe_zero_division(numerator: Tensor, denominator: Tensor, eps: float = 1e-32) -> Tensor:
  88. return numerator / torch.clamp(denominator, min=eps)
  89. def render_gaussian2d(mean: Tensor, std: Tensor, size: tuple[int, int], normalized_coordinates: bool = True) -> Tensor:
  90. r"""Render the PDF of a 2D Gaussian distribution.
  91. Args:
  92. mean: the mean location of the Gaussian to render, :math:`(\mu_x, \mu_y)`. Shape: :math:`(*, 2)`.
  93. std: the standard deviation of the Gaussian to render, :math:`(\sigma_x, \sigma_y)`.
  94. Shape :math:`(*, 2)`. Should be able to be broadcast with `mean`.
  95. size: the (height, width) of the output image.
  96. normalized_coordinates: whether ``mean`` and ``std`` are assumed to use coordinates normalized
  97. in the range of :math:`[-1, 1]`. Otherwise, coordinates are assumed to be in the range of the output shape.
  98. Returns:
  99. tensor including rendered points with shape :math:`(*, H, W)`.
  100. """
  101. if not (std.dtype == mean.dtype and std.device == mean.device):
  102. raise TypeError("Expected inputs to have the same dtype and device")
  103. height, width = size
  104. # Create coordinates grid.
  105. grid = create_meshgrid(height, width, normalized_coordinates, mean.device)
  106. grid = grid.to(mean.dtype)
  107. pos_x = grid[..., 0].view(height, width)
  108. pos_y = grid[..., 1].view(height, width)
  109. # Gaussian PDF = exp(-(x - \mu)^2 / (2 \sigma^2))
  110. # = exp(dists * ks),
  111. # where dists = (x - \mu)^2 and ks = -1 / (2 \sigma^2)
  112. # dists <- (x - \mu)^2
  113. dist_x = (pos_x - mean[..., 0, None, None]) ** 2
  114. dist_y = (pos_y - mean[..., 1, None, None]) ** 2
  115. # ks <- -1 / (2 \sigma^2)
  116. k_x = -0.5 * torch.reciprocal(std[..., 0, None, None])
  117. k_y = -0.5 * torch.reciprocal(std[..., 1, None, None])
  118. # Assemble the 2D Gaussian.
  119. exps_x = torch.exp(dist_x * k_x)
  120. exps_y = torch.exp(dist_y * k_y)
  121. gauss = exps_x * exps_y
  122. # Rescale so that values sum to one.
  123. val_sum = gauss.sum(-2, keepdim=True).sum(-1, keepdim=True)
  124. gauss = _safe_zero_division(gauss, val_sum)
  125. return gauss