# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Dict, Literal, Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.enhance.normalize import Normalize
from kornia.utils.helpers import map_location_to_cpu

from .dedode_models import DeDoDeDescriptor, DeDoDeDetector, get_descriptor, get_detector
from .utils import dedode_denormalize_pixel_coordinates, sample_keypoints

# Checkpoint URLs, keyed first by component ("detector" / "descriptor") and then by checkpoint name.
# The first character of each checkpoint name is the architecture kind passed to get_detector/get_descriptor.
urls: Dict[str, Dict[str, str]] = {
    "detector": {
        "L-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_detector_L.pth",
        "L-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_C4.pth",
        "L-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/dedode_detector_SO2.pth",
        "L-C4-v2": "https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_L_v2.pth",
    },
    "descriptor": {
        "B-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_B.pth",
        "B-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_C4_Perm_descriptor_setting_C.pth",
        "B-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/B_SO2_Spread_descriptor_setting_C.pth",
        "G-upright": "https://github.com/Parskatt/DeDoDe/releases/download/dedode_pretrained_models/dedode_descriptor_G.pth",
        "G-C4": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_C4_Perm_descriptor_setting_C.pth",
        "G-SO2": "https://github.com/georg-bn/rotation-steerers/releases/download/release-2/G_SO2_Spread_descriptor_setting_C.pth",
    },
}


class DeDoDe(Module):
    r"""Module which detects and/or describes local features in an image using the DeDode method.

    See :cite:`edstedt2024dedode` for details.

    .. note:: The underlying networks operate on ImageNet-normalized images. By default
        (``apply_imagenet_normalization=True``) :meth:`forward`, :meth:`detect` and :meth:`describe`
        apply that normalization themselves, so images should be passed in the range [0, 1].
        Pass ``apply_imagenet_normalization=False`` if the images are already ImageNet-normalized.

    Args:
        detector_model: The detector model kind. Available options are: `L`.
        descriptor_model: The descriptor model kind. Available options are: `G` or `B`
        amp_dtype: The automatic mixed precision desired.

    Example:
        >>> dedode = DeDoDe.from_pretrained(detector_weights="L-C4-v2", descriptor_weights="B-upright")
        >>> images = torch.randn(1, 3, 256, 256)
        >>> keypoints, scores = dedode.detect(images)
        >>> descriptions = dedode.describe(images, keypoints = keypoints)
        >>> keypoints, scores, features = dedode(images) # alternatively do both
    """

    # TODO: implement steerers and mnn matchers
    def __init__(
        self,
        detector_model: Literal["L"] = "L",
        descriptor_model: Literal["G", "B"] = "G",
        amp_dtype: torch.dtype = torch.float16,
    ) -> None:
        super().__init__()
        self.detector: DeDoDeDetector = get_detector(detector_model, amp_dtype)
        self.descriptor: DeDoDeDescriptor = get_descriptor(descriptor_model, amp_dtype)
        # ImageNet channel statistics.
        self.normalizer = Normalize(torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))

    @staticmethod
    def _pad_to_multiple_of_14(images: Tensor) -> Tensor:
        """Zero-pad ``images`` on the bottom and right so that H and W become multiples of 14.

        NOTE(review): 14 is presumably the DINOv2 ViT patch size used by the backbones -- confirm.
        """
        height, width = images.shape[-2:]
        # Python's modulo of a negative number gives the distance to the next multiple of 14 (0 if already one).
        pad_h = -height % 14
        pad_w = -width % 14
        if pad_h == 0 and pad_w == 0:
            return images
        return F.pad(images, (0, pad_w, 0, pad_h), value=0.0)

    def forward(
        self,
        images: Tensor,
        n: Optional[int] = 10_000,
        apply_imagenet_normalization: bool = True,
        pad_if_not_divisible: bool = True,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Detect and describe keypoints in the input images.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images. They should be in the
                range [0, 1] when ``apply_imagenet_normalization`` is True, or already ImageNet-normalized otherwise.
            n: The number of keypoints to detect.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
            pad_if_not_divisible: pad image shape if not evenly divisible.

        Returns:
            keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints in pixel coordinates
                of the input image, unlike the `.detect()` function which returns coordinates in :math:`[-1, 1]`.
            scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
            descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected
                keypoints. DIM is 256 for B and 512 for G.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        height, width = images.shape[-2:]
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        # Pad once, after normalization, so the padded border is zero in normalized space.
        if pad_if_not_divisible:
            images = self._pad_to_multiple_of_14(images)
        # The input is already normalized (and padded), so the sub-calls only crop back to the original size.
        keypoints, scores = self.detect(
            images, n=n, apply_imagenet_normalization=False, crop_h=height, crop_w=width
        )
        descriptions = self.describe(
            images, keypoints, apply_imagenet_normalization=False, crop_h=height, crop_w=width
        )
        return dedode_denormalize_pixel_coordinates(keypoints, height, width), scores, descriptions

    @torch.inference_mode()
    def detect(
        self,
        images: Tensor,
        n: Optional[int] = 10_000,
        apply_imagenet_normalization: bool = True,
        pad_if_not_divisible: bool = True,
        crop_h: Optional[int] = None,
        crop_w: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Detect keypoints in the input images.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            n: The number of keypoints to detect.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
            pad_if_not_divisible: pad image shape if not evenly divisible.
            crop_h: The height of the crop to be used for detection. If None, the full image is used.
            crop_w: The width of the crop to be used for detection. If None, the full image is used.
                The crop is only applied when both ``crop_h`` and ``crop_w`` are given.

        Returns:
            keypoints: A tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
                normalized to the range :math:`[-1, 1]` of the (cropped) image.
            scores: A tensor of shape :math:`(B, N)` containing the scores of the detected keypoints.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        self.train(False)
        batch_size, _, height, width = images.shape
        # Normalize before padding so the padded border matches what forward() feeds the detector
        # (zero in normalized space) and detect() gives the same keypoints as forward().
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        if pad_if_not_divisible:
            images = self._pad_to_multiple_of_14(images)
        logits = self.detector.forward(images)
        # Remove the padding, if any
        logits = logits[..., :height, :width]
        if crop_h is not None and crop_w is not None:
            logits = logits[..., :crop_h, :crop_w]
        # Use the actual size after slicing, so a crop larger than the image cannot break the reshape.
        height, width = logits.shape[-2:]
        # Softmax over all pixels turns the logits into a per-image distribution of keypoint likelihood.
        scoremap = logits.reshape(batch_size, height * width).softmax(dim=-1).reshape(batch_size, height, width)
        keypoints, confidence = sample_keypoints(scoremap, num_samples=n)
        return keypoints, confidence

    @torch.inference_mode()
    def describe(
        self,
        images: Tensor,
        keypoints: Optional[Tensor] = None,
        apply_imagenet_normalization: bool = True,
        crop_h: Optional[int] = None,
        crop_w: Optional[int] = None,
    ) -> Tensor:
        """Describe keypoints in the input images. If keypoints are not provided, returns the dense descriptors.

        Args:
            images: A tensor of shape :math:`(B, 3, H, W)` containing the input images.
            keypoints: An optional tensor of shape :math:`(B, N, 2)` containing the detected keypoints,
                in the normalized :math:`[-1, 1]` coordinates returned by `.detect()`.
            apply_imagenet_normalization: Whether to apply ImageNet normalization to the input images.
            crop_h: The height of the crop to be used for description. If None, the full image is used.
            crop_w: The width of the crop to be used for description. If None, the full image is used.
                The crop is only applied when both ``crop_h`` and ``crop_w`` are given.

        Returns:
            descriptions: A tensor of shape :math:`(B, N, DIM)` containing the descriptions of the detected keypoints.
            If the dense descriptors are requested, the shape is :math:`(B, DIM, H, W)`.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        if keypoints is not None:
            KORNIA_CHECK_SHAPE(keypoints, ["B", "N", "2"])
        if apply_imagenet_normalization:
            images = self.normalizer(images)
        self.train(False)
        descriptions = self.descriptor.forward(images)
        if crop_h is not None and crop_w is not None:
            descriptions = descriptions[..., :crop_h, :crop_w]
        if keypoints is None:
            return descriptions
        # grid_sample needs input and grid of the same dtype; the descriptor map is sampled in float32.
        dense = descriptions.float()
        # Treat the N keypoints as a 1 x N sampling grid: (B, 1, N, 2) -> (B, DIM, 1, N) -> (B, N, DIM).
        grid = keypoints[:, None].to(dense.dtype)
        sampled = F.grid_sample(dense, grid, mode="bilinear", align_corners=False)
        return sampled[:, :, 0].mT

    @classmethod
    def from_pretrained(
        cls,
        detector_weights: str = "L-C4-v2",
        descriptor_weights: str = "G-upright",
        amp_dtype: torch.dtype = torch.float16,
    ) -> "DeDoDe":
        r"""Load a pretrained model.

        Args:
            detector_weights: The weights to load for the detector.
                One of 'L-upright' (original paper, https://arxiv.org/abs/2308.08479),
                'L-C4', 'L-SO2' (from steerers, better for rotations, https://arxiv.org/abs/2312.02152),
                'L-C4-v2' (from dedode v2, better at rotations, less clustering, https://arxiv.org/abs/2404.08928).
                Default is 'L-C4-v2'.
            descriptor_weights: The weights to load for the descriptor.
                One of 'B-upright','G-upright' (original paper, https://arxiv.org/abs/2308.08479),
                'B-C4', 'B-SO2', 'G-C4', 'G-SO2' (from steerers, better for rotations, https://arxiv.org/abs/2312.02152).
                Default is 'G-upright'.
            amp_dtype: the dtype to use for the model. One of torch.float16 or torch.float32.
                Default is torch.float16, suitable for CUDA. Use torch.float32 for CPU or MPS.

        Returns:
            The pretrained model, in eval mode.

        Raises:
            KeyError: if ``detector_weights`` or ``descriptor_weights`` is not a known checkpoint name.
        """
        # Validate names up front so an unknown checkpoint fails before the networks are built.
        if detector_weights not in urls["detector"]:
            raise KeyError(
                f"Unknown detector_weights '{detector_weights}'. Available: {sorted(urls['detector'])}"
            )
        if descriptor_weights not in urls["descriptor"]:
            raise KeyError(
                f"Unknown descriptor_weights '{descriptor_weights}'. Available: {sorted(urls['descriptor'])}"
            )
        # The first character of the checkpoint name selects the architecture ("L", "B" or "G").
        model: DeDoDe = cls(
            detector_model=detector_weights[0],  # type: ignore[arg-type]
            descriptor_model=descriptor_weights[0],  # type: ignore[arg-type]
            amp_dtype=amp_dtype,
        )
        model.detector.load_state_dict(
            torch.hub.load_state_dict_from_url(urls["detector"][detector_weights], map_location=torch.device("cpu"))
        )
        model.descriptor.load_state_dict(
            torch.hub.load_state_dict_from_url(urls["descriptor"][descriptor_weights], map_location=torch.device("cpu"))
        )
        model.eval()
        return model