| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from __future__ import annotations
- from typing import Optional
- import torch
- import torch.nn.functional as F
- from kornia.core import Tensor
- from .structs import Keypoints
def nms(signal: Tensor, window_size: int = 5, cutoff: Optional[float] = 0.0) -> Tensor:
    """Apply non-maximum suppression.

    A location is kept when it is the arg-max of the ``window_size`` x ``window_size``
    neighbourhood centred on it and, unless ``cutoff`` is ``None``, its value is
    strictly greater than ``cutoff``.

    Args:
        signal: score map whose last two dimensions are ``(H, W)``. It is handed
            as-is to :func:`torch.nn.functional.max_pool2d`, so it must have a
            shape that function accepts (``(C, H, W)`` or ``(B, C, H, W)``).
        window_size: side length of the square suppression window. Must be odd.
        cutoff: exclusive lower bound on the score of a kept location. ``None``
            disables the score test.

    Returns:
        Boolean tensor with the same shape as ``signal``; ``True`` marks the
        locations that survive suppression.

    Raises:
        ValueError: if ``window_size`` is not odd.
    """
    if window_size % 2 != 1:
        raise ValueError(f"window_size has to be odd, got {window_size}")

    # stride 1 with "same" padding gives one pooling window per input location.
    _, ixs = F.max_pool2d(signal, kernel_size=window_size, stride=1, padding=window_size // 2, return_indices=True)

    # The returned indices are flattened over each (H, W) plane, so a single
    # (H, W) grid of flat positions broadcasts against any leading dimensions.
    h, w = signal.shape[-2:]
    coords = torch.arange(h * w, device=signal.device).reshape(h, w)

    # A location is a local maximum iff the window centred on it selects itself.
    is_peak = ixs == coords

    if cutoff is None:
        return is_peak
    return is_peak & (signal > cutoff)
def heatmap_to_keypoints(
    heatmap: Tensor, n: Optional[int] = None, window_size: int = 5, score_threshold: float = 0.0
) -> list[Keypoints]:
    """Inference-time nms-based detection protocol.

    Runs :func:`nms` on the heatmap and, for every batch element, collects the
    surviving locations and their heatmap values into a :class:`Keypoints`.

    Args:
        heatmap: detection score map. Dimension 1 is squeezed away before use,
            so this presumably expects ``(B, 1, H, W)`` -- TODO confirm with callers.
        n: maximum number of keypoints to keep per batch element. ``None`` keeps
            every location that survives NMS. When more than ``n`` locations
            survive, only those scoring strictly above the ``(n + 1)``-th best
            are kept, so ties at that score can yield fewer than ``n`` keypoints.
        window_size: odd side length of the NMS window, forwarded to :func:`nms`.
        score_threshold: exclusive minimum score, forwarded to :func:`nms` as ``cutoff``.

    Returns:
        One :class:`Keypoints` per batch element, built from the ``(x, y)``
        coordinates of the detections and their heatmap values.
    """
    heatmap = heatmap.squeeze(1)
    nmsed = nms(heatmap, window_size=window_size, cutoff=score_threshold)

    keypoints = []
    for scores, peaks in zip(heatmap, nmsed):
        # nonzero() yields (row, col) pairs; flip them to (x, y).
        xy = peaks.nonzero(as_tuple=False).flip((1,))
        detection_logp = scores[peaks]

        # Only threshold when there really are more than n candidates. With n or
        # fewer, the (n + 1)-th best score does not exist: clamping k to the
        # number of candidates would discard the weakest detection, and k == 0
        # (no candidates at all) makes torch.kthvalue raise.
        if n is not None and detection_logp.numel() > n:
            # torch.kthvalue picks in ascending order and we want to pick in
            # descending order, so we pick the (n + 1)-th smallest among -logp
            # to get -threshold
            minus_threshold, _indices = torch.kthvalue(-detection_logp, n + 1)
            mask = detection_logp > -minus_threshold

            xy = xy[mask]
            detection_logp = detection_logp[mask]

            # defensive clip: never return more than n keypoints
            xy = xy[:n]
            detection_logp = detection_logp[:n]

        keypoints.append(Keypoints(xy, detection_logp))

    return keypoints
|