dedode.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Dict, Literal, Optional, Tuple
  18. import torch
  19. import torch.nn.functional as F
  20. from kornia.core import Module, Tensor
  21. from kornia.core.check import KORNIA_CHECK_SHAPE
  22. from kornia.enhance.normalize import Normalize
  23. from kornia.utils.helpers import map_location_to_cpu
  24. from .dedode_models import DeDoDeDescriptor, DeDoDeDetector, get_descriptor, get_detector
  25. from .utils import dedode_denormalize_pixel_coordinates, sample_keypoints
  26. urls: Dict[str, Dict[str, str]] = {
  27. "detector": {
  28. "L-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
  29. "L-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_C4.pth",
  30. "L-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_SO2.pth",
  31. "L-C4-v2": "https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_L_v2.pth",
  32. },
  33. "descriptor": {
  34. "B-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
  35. "B-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_C4_Perm_descriptor_setting_C.pth",
  36. "B-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_SO2_Spread_descriptor_setting_C.pth",
  37. "G-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth",
  38. "G-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_C4_Perm_descriptor_setting_C.pth",
  39. "G-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_SO2_Spread_descriptor_setting_C.pth",
  40. },
  41. }
  42. class DeDoDe(Module):
  43. r"""Module which detects and/or describes local features in an image using the DeDode method.
  44. See :cite:`edstedt2024dedode` for details.
  45. .. note:: DeDode takes ImageNet normalized images as input (not in range [0, 1]).
  46. Args:
  47. detector_model: The detector model kind. Available options are: `L`.
  48. descriptor_model: The descriptor model kind. Available options are: `G` or `B`
  49. amp_dtype: The automatic mixed precision desired.
  50. Example:
  51. >>> dedode = DeDoDe.from_pretrained(detector_weights="L-C4-v2", descriptor_weights="B-upright")
  52. >>> images = torch.randn(1, 3, 256, 256)
  53. >>> keypoints, scores = dedode.detect(images)
  54. >>> descriptions = dedode.describe(images, keypoints = keypoints)
  55. >>> keypoints, scores, features = dedode(images) # alternatively do both
  56. """
  57. # TODO: implement steerers and mnn matchers
  58. def __init__(
  59. self,
  60. detector_model: Literal["L"] = "L",
  61. descriptor_model: Literal["G", "B"] = "G",
  62. amp_dtype: torch.dtype = torch.float16,
  63. ) -> None:
  64. super().__init__()
  65. self.detector: DeDoDeDetector = get_detector(detector_model, amp_dtype)
  66. self.descriptor: DeDoDeDescriptor = get_descriptor(descriptor_model, amp_dtype)
  67. self.normalizer = Normalize(torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
  68. def forward(
  69. self,
  70. images: Tensor,
  71. n: Optional[int] = 10_000,
  72. apply_imagenet_normalization: bool = True,
  73. pad_if_not_divisible: bool = True,
  74. ) -> Tuple[Tensor, Tensor, Tensor]:
  75. """Detect and describe keypoints in the input images.
  76. Args:
  77. images: A tensor of shape :math:`(B, 3, H, W)` containing the ImageNet-Normalized input images.
  78. n: The number of keypoints to detect.
  79. apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
  80. pad_if_not_divisible: pad image shape if not evenly divisible.
  81. Returns:
  82. keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints in the image range,
  83. unlike `.detect()` function
  84. scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
  85. descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected keypoints.
  86. DIM is 256 for B and 512 for G.
  87. """
  88. if apply_imagenet_normalization:
  89. images = self.normalizer(images)
  90. _B, _C, H, W = images.shape
  91. h, w = images.shape[2:]
  92. if pad_if_not_divisible:
  93. pd_h = 14 - h % 14 if h % 14 > 0 else 0
  94. pd_w = 14 - w % 14 if w % 14 > 0 else 0
  95. images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
  96. keypoints, scores = self.detect(images, n=n, apply_imagenet_normalization=False, crop_h=h, crop_w=w)
  97. descriptions = self.describe(images, keypoints, apply_imagenet_normalization=False, crop_h=h, crop_w=w)
  98. return dedode_denormalize_pixel_coordinates(keypoints, H, W), scores, descriptions
  99. @torch.inference_mode()
  100. def detect(
  101. self,
  102. images: Tensor,
  103. n: Optional[int] = 10_000,
  104. apply_imagenet_normalization: bool = True,
  105. pad_if_not_divisible: bool = True,
  106. crop_h: Optional[int] = None,
  107. crop_w: Optional[int] = None,
  108. ) -> Tuple[Tensor, Tensor]:
  109. """Detect keypoints in the input images.
  110. Args:
  111. images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
  112. n: The number of keypoints to detect.
  113. apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
  114. pad_if_not_divisible: pad image shape if not evenly divisible.
  115. crop_h: The height of the crop to be used for detection. If None, the full image is used.
  116. crop_w: The width of the crop to be used for detection. If None, the full image is used.
  117. Returns:
  118. keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
  119. normalized to the range :math:`[-1, 1]`.
  120. scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
  121. """
  122. KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
  123. self.train(False)
  124. B, _C, H, W = images.shape
  125. if pad_if_not_divisible:
  126. h, w = images.shape[2:]
  127. pd_h = 14 - h % 14 if h % 14 > 0 else 0
  128. pd_w = 14 - w % 14 if w % 14 > 0 else 0
  129. images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
  130. if apply_imagenet_normalization:
  131. images = self.normalizer(images)
  132. logits = self.detector.forward(images)
  133. # Remove the padding, if any
  134. logits = logits[..., :H, :W]
  135. if crop_h is not None and crop_w is not None:
  136. logits = logits[..., :crop_h, :crop_w]
  137. H, W = crop_h, crop_w
  138. scoremap = logits.reshape(B, H * W).softmax(dim=-1).reshape(B, H, W)
  139. keypoints, confidence = sample_keypoints(scoremap, num_samples=n)
  140. return keypoints, confidence
  141. @torch.inference_mode()
  142. def describe(
  143. self,
  144. images: Tensor,
  145. keypoints: Optional[Tensor] = None,
  146. apply_imagenet_normalization: bool = True,
  147. crop_h: Optional[int] = None,
  148. crop_w: Optional[int] = None,
  149. ) -> Tensor:
  150. """Describe keypoints in the input images. If keypoints are not provided, returns the dense descriptors.
  151. Args:
  152. images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
  153. keypoints: An optional tensor of shape :math:`(B, N, 2)` containing the detected keypoints.
  154. apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
  155. crop_h: The height of the crop to be used for description. If None, the full image is used.
  156. crop_w: The width of the crop to be used for description. If None, the full image is used.
  157. Returns:
  158. descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected keypoints.
  159. If the dense descriptors are requested, the shape is :math:`(B, DIM, H, W)`.
  160. """
  161. KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
  162. _B, _C, H, W = images.shape
  163. if keypoints is not None:
  164. KORNIA_CHECK_SHAPE(keypoints, ["B", "N", "2"])
  165. if apply_imagenet_normalization:
  166. images = self.normalizer(images)
  167. self.train(False)
  168. descriptions = self.descriptor.forward(images)
  169. if crop_h is not None and crop_w is not None:
  170. descriptions = descriptions[..., :crop_h, :crop_w]
  171. H, W = crop_h, crop_w
  172. if keypoints is not None:
  173. described_keypoints = F.grid_sample(
  174. descriptions.float(), keypoints[:, None], mode="bilinear", align_corners=False
  175. )[:, :, 0].mT
  176. return described_keypoints
  177. return descriptions
  178. @classmethod
  179. def from_pretrained(
  180. cls,
  181. detector_weights: str = "L-C4-v2",
  182. descriptor_weights: str = "G-upright",
  183. amp_dtype: torch.dtype = torch.float16,
  184. ) -> Module:
  185. r"""Load a pretrained model.
  186. Args:
  187. detector_weights: The weights to load for the detector.
  188. One of 'L-upright' (original paper, https://arxiv.org/abs/2308.08479),
  189. 'L-C4', 'L-SO2' (from steerers, better for rotations, https://arxiv.org/abs/2312.02152),
  190. 'L-C4-v2' (from dedode v2, better at rotations, less clustering, https://arxiv.org/abs/2404.08928)
  191. Default is 'L-C4-v2', but perhaps it should be 'L-C4-v2'?
  192. descriptor_weights: The weights to load for the descriptor.
  193. One of 'B-upright','G-upright' (original paper, https://arxiv.org/abs/2308.08479),
  194. 'B-C4', 'B-SO2', 'G-C4', 'G-SO2' (from steerers, better for rotations, https://arxiv.org/abs/2312.02152).
  195. Default is 'G-upright'.
  196. amp_dtype: the dtype to use for the model. One of torch.float16 or torch.float32.
  197. Default is torch.float16, suitable for CUDA. Use torch.float32 for CPU or MPS
  198. Returns:
  199. The pretrained model.
  200. """
  201. model: DeDoDe = cls(
  202. detector_model=detector_weights[0], # type: ignore[arg-type]
  203. descriptor_model=descriptor_weights[0], # type: ignore[arg-type]
  204. amp_dtype=amp_dtype,
  205. )
  206. model.detector.load_state_dict(
  207. torch.hub.load_state_dict_from_url(urls["detector"][detector_weights], map_location=torch.device("cpu"))
  208. )
  209. model.descriptor.load_state_dict(
  210. torch.hub.load_state_dict_from_url(urls["descriptor"][descriptor_weights], map_location=torch.device("cpu"))
  211. )
  212. model.eval()
  213. return model