utils.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Optional, Tuple, Union
  18. import torch
  19. import torch.nn.functional as F
  20. from kornia.core import Tensor
  21. @torch.no_grad()
  22. def sample_keypoints(
  23. scoremap: Tensor, num_samples: Optional[int] = 10_000, return_scoremap: bool = True, increase_coverage: bool = True
  24. ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
  25. """Sample keypoints from provided candidates."""
  26. device = scoremap.device
  27. dtype = scoremap.dtype
  28. B, H, W = scoremap.shape
  29. if increase_coverage:
  30. weights = (-(torch.linspace(-2, 2, steps=51, device=device, dtype=dtype) ** 2)).exp()[None, None]
  31. # 10000 is just some number for maybe numerical stability, who knows. :), result is invariant anyway
  32. local_density_x = F.conv2d((scoremap[:, None] + 1e-6) * 10000, weights[..., None, :], padding=(0, 51 // 2))
  33. local_density = F.conv2d(local_density_x, weights[..., None], padding=(51 // 2, 0))[:, 0]
  34. scoremap = scoremap * (local_density + 1e-8) ** (-1 / 2)
  35. grid = get_grid(B, H, W, device=device).reshape(B, H * W, 2)
  36. inds = torch.topk(scoremap.reshape(B, H * W), k=num_samples).indices # type: ignore
  37. kps = torch.gather(grid, dim=1, index=inds[..., None].expand(B, num_samples, 2)) # type: ignore
  38. if return_scoremap:
  39. return kps, torch.gather(scoremap.reshape(B, H * W), dim=1, index=inds)
  40. return kps
  41. def get_grid(B: int, H: int, W: int, device: torch.device) -> torch.Tensor:
  42. """Get grid of provided layout."""
  43. xs = (torch.arange(W, device=device) + 0.5) / W * 2 - 1
  44. ys = (torch.arange(H, device=device) + 0.5) / H * 2 - 1
  45. yy, xx = torch.meshgrid(ys, xs, indexing="ij")
  46. base = torch.stack((xx, yy), dim=-1).reshape(1, H * W, 2)
  47. return base.expand(B, -1, -1)
  48. def dedode_denormalize_pixel_coordinates(flow: torch.Tensor, h: int, w: int) -> torch.Tensor:
  49. """Denormalize pixel coordinates."""
  50. flow = torch.stack(
  51. (
  52. w * (flow[..., 0] + 1) / 2,
  53. h * (flow[..., 1] + 1) / 2,
  54. ),
  55. dim=-1,
  56. )
  57. return flow