import kornia
import torch

from .utils import Extractor
  4. class DISK(Extractor):
  5. default_conf = {
  6. "weights": "depth",
  7. "max_num_keypoints": None,
  8. "desc_dim": 128,
  9. "nms_window_size": 5,
  10. "detection_threshold": 0.0,
  11. "pad_if_not_divisible": True,
  12. }
  13. preprocess_conf = {
  14. "resize": 1024,
  15. "grayscale": False,
  16. }
  17. required_data_keys = ["image"]
  18. def __init__(self, **conf) -> None:
  19. super().__init__(**conf) # Update with default configuration.
  20. self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
  21. def forward(self, data: dict) -> dict:
  22. """Compute keypoints, scores, descriptors for image"""
  23. for key in self.required_data_keys:
  24. assert key in data, f"Missing key {key} in data"
  25. image = data["image"]
  26. if image.shape[1] == 1:
  27. image = kornia.color.grayscale_to_rgb(image)
  28. features = self.model(
  29. image,
  30. n=self.conf.max_num_keypoints,
  31. window_size=self.conf.nms_window_size,
  32. score_threshold=self.conf.detection_threshold,
  33. pad_if_not_divisible=self.conf.pad_if_not_divisible,
  34. )
  35. keypoints = [f.keypoints for f in features]
  36. scores = [f.detection_scores for f in features]
  37. descriptors = [f.descriptors for f in features]
  38. del features
  39. keypoints = torch.stack(keypoints, 0)
  40. scores = torch.stack(scores, 0)
  41. descriptors = torch.stack(descriptors, 0)
  42. return {
  43. "keypoints": keypoints.to(image).contiguous(),
  44. "keypoint_scores": scores.to(image).contiguous(),
  45. "descriptors": descriptors.to(image).contiguous(),
  46. }