detector.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import Optional
  19. import torch
  20. import torch.nn.functional as F
  21. from kornia.core import Tensor
  22. from .structs import Keypoints
  23. def nms(signal: Tensor, window_size: int = 5, cutoff: float = 0.0) -> Tensor:
  24. """Apply non-maximum suppression."""
  25. if window_size % 2 != 1:
  26. raise ValueError(f"window_size has to be odd, got {window_size}")
  27. _, ixs = F.max_pool2d(signal, kernel_size=window_size, stride=1, padding=window_size // 2, return_indices=True)
  28. h, w = signal.shape[1:]
  29. coords = torch.arange(h * w, device=signal.device).reshape(1, h, w)
  30. nms = ixs == coords
  31. if cutoff is None:
  32. return nms
  33. else:
  34. return nms & (signal > cutoff)
  35. def heatmap_to_keypoints(
  36. heatmap: Tensor, n: Optional[int] = None, window_size: int = 5, score_threshold: float = 0.0
  37. ) -> list[Keypoints]:
  38. """Inference-time nms-based detection protocol."""
  39. heatmap = heatmap.squeeze(1)
  40. nmsed = nms(heatmap, window_size=window_size, cutoff=score_threshold)
  41. keypoints = []
  42. for b in range(heatmap.shape[0]):
  43. yx = nmsed[b].nonzero(as_tuple=False)
  44. detection_logp = heatmap[b][nmsed[b]]
  45. xy = yx.flip((1,))
  46. if n is not None:
  47. n_ = min(n + 1, detection_logp.numel())
  48. # torch.kthvalue picks in ascending order and we want to pick in
  49. # descending order, so we pick n-th smallest among -logp to get
  50. # -threshold
  51. minus_threshold, _indices = torch.kthvalue(-detection_logp, n_)
  52. mask = detection_logp > -minus_threshold
  53. xy = xy[mask]
  54. detection_logp = detection_logp[mask]
  55. # it may be that due to numerical saturation on the threshold we have
  56. # more than n keypoints, so we need to clip them
  57. xy = xy[:n]
  58. detection_logp = detection_logp[:n]
  59. keypoints.append(Keypoints(xy, detection_logp))
  60. return keypoints