# dog_hardnet.py
  1. import torch
  2. from kornia.color import rgb_to_grayscale
  3. from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
  4. from .sift import SIFT
  5. class DoGHardNet(SIFT):
  6. required_data_keys = ["image"]
  7. def __init__(self, **conf):
  8. super().__init__(**conf)
  9. self.laf_desc = LAFDescriptor(HardNet(True)).eval()
  10. def forward(self, data: dict) -> dict:
  11. image = data["image"]
  12. if image.shape[1] == 3:
  13. image = rgb_to_grayscale(image)
  14. device = image.device
  15. self.laf_desc = self.laf_desc.to(device)
  16. self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
  17. pred = []
  18. if "image_size" in data.keys():
  19. im_size = data.get("image_size").long()
  20. else:
  21. im_size = None
  22. for k in range(len(image)):
  23. img = image[k]
  24. if im_size is not None:
  25. w, h = data["image_size"][k]
  26. img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
  27. p = self.extract_single_image(img)
  28. lafs = laf_from_center_scale_ori(
  29. p["keypoints"].reshape(1, -1, 2),
  30. 6.0 * p["scales"].reshape(1, -1, 1, 1),
  31. torch.rad2deg(p["oris"]).reshape(1, -1, 1),
  32. ).to(device)
  33. p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
  34. pred.append(p)
  35. pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
  36. return pred