| 1234567891011121314151617181920212223242526272829303132333435363738394041 |
- import torch
- from kornia.color import rgb_to_grayscale
- from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
- from .sift import SIFT
class DoGHardNet(SIFT):
    """SIFT-derived extractor that describes keypoints with HardNet.

    Keypoints, scales and orientations are obtained per image from the parent
    class through ``extract_single_image``. For every keypoint a
    128-dimensional descriptor is then computed by kornia's ``LAFDescriptor``
    wrapping a ``HardNet`` model, and stored under ``"descriptors"``.
    """

    required_data_keys = ["image"]

    def __init__(self, **conf):
        """Build the parent extractor and the HardNet patch descriptor.

        Args:
            **conf: Configuration forwarded unchanged to ``SIFT.__init__``.
        """
        super().__init__(**conf)
        # NOTE(review): the positional ``True`` is presumably HardNet's
        # ``pretrained`` flag -- confirm against the kornia signature.
        self.laf_desc = LAFDescriptor(HardNet(True)).eval()

    def forward(self, data: dict) -> dict:
        """Extract keypoints and HardNet descriptors for a batch of images.

        Args:
            data: Input dictionary. ``data["image"]`` is the image batch; when
                its second dimension has size 3 it is converted to grayscale.
                An optional ``data["image_size"]`` tensor holds one
                ``(width, height)`` pair per image and is used to crop each
                image before extraction. A missing or ``None`` entry disables
                cropping.

        Returns:
            A dictionary with every key produced by ``extract_single_image``
            plus ``"descriptors"``, each value stacked over the batch along a
            new leading dimension and moved to the device of the input image.

        Raises:
            RuntimeError: Raised by ``torch.stack`` when the per-image tensors
                of a key do not all share the same shape.
        """
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device

        # Keep the descriptor on the same device as the input, and force the
        # inner network into eval mode on every call.
        self.laf_desc = self.laf_desc.to(device)
        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()

        # Convert the optional sizes to integers once, instead of re-reading
        # and re-casting them for every image in the loop.
        image_size = data.get("image_size")
        if image_size is not None:
            image_size = image_size.long()

        per_image = []
        for idx, img in enumerate(image):
            if image_size is not None:
                # Sizes are stored as (width, height); the tensor layout is
                # (channel, height, width), hence the swapped slice order.
                w, h = image_size[idx].tolist()
                img = img[:, :h, :w]
            p = self.extract_single_image(img)
            # NOTE(review): the 6.0 factor scales the keypoint size into the
            # local affine frame; presumably it sets the patch extent fed to
            # HardNet -- confirm against the kornia reference usage.
            lafs = laf_from_center_scale_ori(
                p["keypoints"].reshape(1, -1, 2),
                6.0 * p["scales"].reshape(1, -1, 1, 1),
                torch.rad2deg(p["oris"]).reshape(1, -1, 1),
            ).to(device)
            p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
            per_image.append(p)

        return {
            key: torch.stack([p[key] for p in per_image], 0).to(device)
            for key in per_image[0]
        }
|