# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from __future__ import annotations from dataclasses import dataclass from typing import Any import torch.nn.functional as F from kornia.core import Device, Tensor @dataclass class DISKFeatures: r"""A data structure holding DISK keypoints, descriptors and detection scores for an image. Since DISK detects a varying number of keypoints per image, `DISKFeatures` is not batched. Args: keypoints: Tensor of shape :math:`(N, 2)`, where :math:`N` is the number of keypoints. descriptors: Tensor of shape :math:`(N, D)`, where :math:`D` is the descriptor dimension. detection_scores: Tensor of shape :math:`(N,)` where the detection score can be interpreted as the log-probability of keeping a keypoint after it has been proposed (see the paper section *Method → Feature distribution* for details). """ keypoints: Tensor descriptors: Tensor detection_scores: Tensor @property def n(self) -> int: return self.keypoints.shape[0] @property def device(self) -> Device: return self.keypoints.device @property def x(self) -> Tensor: """Accesses the x coordinates of keypoints (along image width).""" return self.keypoints[:, 0] @property def y(self) -> Tensor: """Accesses the y coordinates of keypoints (along image height).""" return self.keypoints[:, 1] def to(self, *args: Any, **kwargs: Any) -> DISKFeatures: """Call :func:`torch.Tensor.to` on each tensor to move the keypoints, descriptors and detection scores to the specified device and/or data type. Args: *args: Arguments passed to :func:`torch.Tensor.to`. **kwargs: Keyword arguments passed to :func:`torch.Tensor.to`. Returns: A new DISKFeatures object with tensors of appropriate type and location. """ # noqa:D205 return DISKFeatures( self.keypoints.to(*args, **kwargs), self.descriptors.to(*args, **kwargs), self.detection_scores.to(*args, **kwargs), ) @dataclass class Keypoints: """A temporary struct used to store keypoint detections and their log-probabilities. After construction, merge_with_descriptors is used to select corresponding descriptors from unet output. """ xys: Tensor detection_logp: Tensor def merge_with_descriptors(self, descriptors: Tensor) -> DISKFeatures: """Select descriptors from a dense `descriptors` tensor, at locations given by `self.xys`.""" dtype = descriptors.dtype x, y = self.xys.T desc = descriptors[:, y, x].T desc = F.normalize(desc, dim=-1) return DISKFeatures(self.xys.to(dtype), desc, self.detection_logp)