|
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Dict, Literal, Optional, Tuple
- import torch
- import torch.nn.functional as F
- from kornia.core import Module, Tensor
- from kornia.core.check import KORNIA_CHECK_SHAPE
- from kornia.enhance.normalize import Normalize
- from kornia.utils.helpers import map_location_to_cpu
- from .dedode_models import DeDoDeDescriptor, DeDoDeDetector, get_descriptor, get_detector
- from .utils import dedode_denormalize_pixel_coordinates, sample_keypoints
- urls: Dict[str, Dict[str, str]] = {
- "detector": {
- "L-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
- "L-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_C4.pth",
- "L-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_SO2.pth",
- "L-C4-v2": "https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_L_v2.pth",
- },
- "descriptor": {
- "B-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
- "B-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_C4_Perm_descriptor_setting_C.pth",
- "B-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_SO2_Spread_descriptor_setting_C.pth",
- "G-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth",
- "G-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_C4_Perm_descriptor_setting_C.pth",
- "G-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_SO2_Spread_descriptor_setting_C.pth",
- },
- }
- class DeDoDe(Module):
- r"""Module which detects and/or describes local features in an image using the DeDoDe method.
- See :cite:`edstedt2024dedode` for details.
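- .. note:: The DeDoDe networks take ImageNet-normalized images as input (not in range [0, 1]). With the default `apply_imagenet_normalization=True`, this module applies the normalization itself.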
- Args:
- detector_model: The detector model kind. The only available option is `L`.
- descriptor_model: The descriptor model kind. Available options are `G` or `B`.
- amp_dtype: The dtype used for automatic mixed precision.
- Example:
- >>> dedode = DeDoDe.from_pretrained(detector_weights="L-C4-v2", descriptor_weights="B-upright")
- >>> images = torch.randn(1, 3, 256, 256)
- >>> keypoints, scores = dedode.detect(images)
- >>> descriptions = dedode.describe(images, keypoints=keypoints)
- >>> keypoints, scores, features = dedode(images)  # alternatively do both
- """
- # TODO: implement steerers and mnn matchers
- def __init__(
- self,
- detector_model: Literal["L"] = "L",
- descriptor_model: Literal["G", "B"] = "G",
- amp_dtype: torch.dtype = torch.float16,
- ) -> None:
- super().__init__()
- self.detector: DeDoDeDetector = get_detector(detector_model, amp_dtype)
- self.descriptor: DeDoDeDescriptor = get_descriptor(descriptor_model, amp_dtype)
- self.normalizer = Normalize(torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
- def forward(
- self,
- images: Tensor,
- n: Optional[int] = 10_000,
- apply_imagenet_normalization: bool = True,
- pad_if_not_divisible: bool = True,
- ) -> Tuple[Tensor, Tensor, Tensor]:
- """Detect and describe keypoints in the input images.
- Args:
- images: A tensor of shape :math:`(B, 3, H, W)` containing the input images. They must already be ImageNet-normalized if `apply_imagenet_normalization` is False.
- n: The number of keypoints to detect.
- apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
- pad_if_not_divisible: Whether to zero-pad the images at the bottom and right so that height and width are divisible by 14.
- Returns:
- keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints in image coordinates,
- unlike :meth:`detect`, which returns them normalized to the range :math:`[-1, 1]`.
- scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
- descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected keypoints.
- DIM is 256 for B and 512 for G.
- """
- if apply_imagenet_normalization:
- images = self.normalizer(images)
- _B, _C, H, W = images.shape
- h, w = images.shape[2:]
- if pad_if_not_divisible:
- pd_h = 14 - h % 14 if h % 14 > 0 else 0
- pd_w = 14 - w % 14 if w % 14 > 0 else 0
- images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
- keypoints, scores = self.detect(images, n=n, apply_imagenet_normalization=False, crop_h=h, crop_w=w)
- descriptions = self.describe(images, keypoints, apply_imagenet_normalization=False, crop_h=h, crop_w=w)
- return dedode_denormalize_pixel_coordinates(keypoints, H, W), scores, descriptions
- @torch.inference_mode()
- def detect(
- self,
- images: Tensor,
- n: Optional[int] = 10_000,
- apply_imagenet_normalization: bool = True,
- pad_if_not_divisible: bool = True,
- crop_h: Optional[int] = None,
- crop_w: Optional[int] = None,
- ) -> Tuple[Tensor, Tensor]:
- """Detect keypoints in the input images.
- Args:
- images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
- n: The number of keypoints to detect.
- apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
- pad_if_not_divisible: Whether to zero-pad the images at the bottom and right so that height and width are divisible by 14.
- crop_h: The height of the crop to be used for detection. If None, the full image is used.
- crop_w: The width of the crop to be used for detection. If None, the full image is used.
- Returns:
- keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
- normalized to the range :math:`[-1, 1]`.
- scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
- """
- KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
- self.train(False)
- B, _C, H, W = images.shape
- if pad_if_not_divisible:
- h, w = images.shape[2:]
- pd_h = 14 - h % 14 if h % 14 > 0 else 0
- pd_w = 14 - w % 14 if w % 14 > 0 else 0
- images = torch.nn.functional.pad(images, (0, pd_w, 0, pd_h), value=0.0)
- if apply_imagenet_normalization:
- images = self.normalizer(images)
- logits = self.detector.forward(images)
- # Remove the padding, if any
- logits = logits[..., :H, :W]
- if crop_h is not None and crop_w is not None:
- logits = logits[..., :crop_h, :crop_w]
- H, W = crop_h, crop_w
- scoremap = logits.reshape(B, H * W).softmax(dim=-1).reshape(B, H, W)
- keypoints, confidence = sample_keypoints(scoremap, num_samples=n)
- return keypoints, confidence
- @torch.inference_mode()
- def describe(
- self,
- images: Tensor,
- keypoints: Optional[Tensor] = None,
- apply_imagenet_normalization: bool = True,
- crop_h: Optional[int] = None,
- crop_w: Optional[int] = None,
- ) -> Tensor:
- """Describe keypoints in the input images. If keypoints are not provided, returns the dense descriptors.
- Args:
- images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
- keypoints: An optional tensor of shape :math:`(B, N, 2)` containing the keypoints to describe, normalized to the range :math:`[-1, 1]` as returned by :meth:`detect`.
- apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
- crop_h: The height of the crop to be used for description. If None, the full image is used.
- crop_w: The width of the crop to be used for description. If None, the full image is used.
- Returns:
- descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected keypoints.
- If the dense descriptors are requested, the shape is :math:`(B, DIM, H, W)`.
- """
- KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
- _B, _C, H, W = images.shape
- if keypoints is not None:
- KORNIA_CHECK_SHAPE(keypoints, ["B", "N", "2"])
- if apply_imagenet_normalization:
- images = self.normalizer(images)
- self.train(False)
- descriptions = self.descriptor.forward(images)
- if crop_h is not None and crop_w is not None:
- descriptions = descriptions[..., :crop_h, :crop_w]
- H, W = crop_h, crop_w
- if keypoints is not None:
- described_keypoints = F.grid_sample(
- descriptions.float(), keypoints[:, None], mode="bilinear", align_corners=False
- )[:, :, 0].mT
- return described_keypoints
- return descriptions
- @classmethod
- def from_pretrained(
- cls,
- detector_weights: str = "L-C4-v2",
- descriptor_weights: str = "G-upright",
- amp_dtype: torch.dtype = torch.float16,
- ) -> Module:
- r"""Load a pretrained model.
- Args:
- detector_weights: The weights to load for the detector.
- One of 'L-upright' (original paper, https://arxiv.org/abs/2308.08479),
- 'L-C4', 'L-SO2' (from steerers, better for rotations, https://arxiv.org/abs/2312.02152),
- 'L-C4-v2' (from dedode v2, better at rotations, less clustering, https://arxiv.org/abs/2404.08928)
- Default is 'L-C4-v2'.
- descriptor_weights: The weights to load for the descriptor.
- One of 'B-upright','G-upright' (original paper, https://arxiv.org/abs/2308.08479),
- 'B-C4', 'B-SO2', 'G-C4', 'G-SO2' (from steerers, better for rotations, https://arxiv.org/abs/2312.02152).
- Default is 'G-upright'.
- amp_dtype: The dtype to use for the model. One of torch.float16 or torch.float32.
- Default is torch.float16, suitable for CUDA. Use torch.float32 for CPU or MPS.
- Returns:
- The pretrained model.
- """
- model: DeDoDe = cls(
- detector_model=detector_weights[0], # type: ignore[arg-type]
- descriptor_model=descriptor_weights[0], # type: ignore[arg-type]
- amp_dtype=amp_dtype,
- )
- model.detector.load_state_dict(
- torch.hub.load_state_dict_from_url(urls["detector"][detector_weights], map_location=torch.device("cpu"))
- )
- model.descriptor.load_state_dict(
- torch.hub.load_state_dict_from_url(urls["descriptor"][descriptor_weights], map_location=torch.device("cpu"))
- )
- model.eval()
- return model
|
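The padding arithmetic used in `forward` and `detect` (`14 - h % 14 if h % 14 > 0 else 0`) can be checked in isolation. The following is a minimal sketch in plain Python, not part of the module itself; `pad_amounts` is a hypothetical helper name introduced only for illustration, and the default multiple of 14 is taken from the constants in the code above.

```python
from typing import Tuple


def pad_amounts(height: int, width: int, multiple: int = 14) -> Tuple[int, int]:
    """Return (pad_bottom, pad_right) so that both sides become divisible by `multiple`.

    Hypothetical helper: `(-x) % multiple` is equivalent to the expression
    `multiple - x % multiple if x % multiple > 0 else 0` used in the module.
    """
    pad_h = (-height) % multiple
    pad_w = (-width) % multiple
    return pad_h, pad_w


# A 256 x 256 image is not divisible by 14, so it is padded at the bottom and right.
pad_h, pad_w = pad_amounts(256, 256)
padded_h, padded_w = 256 + pad_h, 256 + pad_w
```

After the networks run on the padded image, the module slices the output back to the original height and width (`logits[..., :H, :W]`), so the padding only affects the input to the networks and not the shape of the score map.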