structs.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from dataclasses import dataclass
  19. from typing import Any
  20. import torch.nn.functional as F
  21. from kornia.core import Device, Tensor
  22. @dataclass
  23. class DISKFeatures:
  24. r"""A data structure holding DISK keypoints, descriptors and detection scores for an image.
  25. Since DISK detects a varying number of keypoints per image, `DISKFeatures` is not batched.
  26. Args:
  27. keypoints: Tensor of shape :math:`(N, 2)`, where :math:`N` is the number of keypoints.
  28. descriptors: Tensor of shape :math:`(N, D)`, where :math:`D` is the descriptor dimension.
  29. detection_scores: Tensor of shape :math:`(N,)` where the detection score can be interpreted as
  30. the log-probability of keeping a keypoint after it has been proposed (see the paper
  31. section *Method → Feature distribution* for details).
  32. """
  33. keypoints: Tensor
  34. descriptors: Tensor
  35. detection_scores: Tensor
  36. @property
  37. def n(self) -> int:
  38. return self.keypoints.shape[0]
  39. @property
  40. def device(self) -> Device:
  41. return self.keypoints.device
  42. @property
  43. def x(self) -> Tensor:
  44. """Accesses the x coordinates of keypoints (along image width)."""
  45. return self.keypoints[:, 0]
  46. @property
  47. def y(self) -> Tensor:
  48. """Accesses the y coordinates of keypoints (along image height)."""
  49. return self.keypoints[:, 1]
  50. def to(self, *args: Any, **kwargs: Any) -> DISKFeatures:
  51. """Call :func:`torch.Tensor.to` on each tensor to move the keypoints, descriptors and detection scores to
  52. the specified device and/or data type.
  53. Args:
  54. *args: Arguments passed to :func:`torch.Tensor.to`.
  55. **kwargs: Keyword arguments passed to :func:`torch.Tensor.to`.
  56. Returns:
  57. A new DISKFeatures object with tensors of appropriate type and location.
  58. """ # noqa:D205
  59. return DISKFeatures(
  60. self.keypoints.to(*args, **kwargs),
  61. self.descriptors.to(*args, **kwargs),
  62. self.detection_scores.to(*args, **kwargs),
  63. )
  64. @dataclass
  65. class Keypoints:
  66. """A temporary struct used to store keypoint detections and their log-probabilities.
  67. After construction, merge_with_descriptors is used to select corresponding descriptors from unet output.
  68. """
  69. xys: Tensor
  70. detection_logp: Tensor
  71. def merge_with_descriptors(self, descriptors: Tensor) -> DISKFeatures:
  72. """Select descriptors from a dense `descriptors` tensor, at locations given by `self.xys`."""
  73. dtype = descriptors.dtype
  74. x, y = self.xys.T
  75. desc = descriptors[:, y, x].T
  76. desc = F.normalize(desc, dim=-1)
  77. return DISKFeatures(self.xys.to(dtype), desc, self.detection_logp)