# sift.py
import warnings

import cv2
import numpy as np
import torch
from kornia.color import rgb_to_grayscale
from packaging import version

try:
    import pycolmap
except ImportError:
    # pycolmap is optional: only required for the "pycolmap*" backends below.
    pycolmap = None

from .utils import Extractor
def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
    """Deduplicate keypoints that land on the same pixel, then optionally NMS.

    Args:
        points: (N, 2) array of (x, y) coordinates. The `- 0.5` below suggests
            they are center-convention coordinates being snapped to integer
            pixel indices — TODO confirm against the caller's convention.
        scales: (N,) keypoint scales.
        angles: (N,) keypoint orientations (radians, per the callers here).
        image_shape: (h, w) of the image the points belong to.
        nms_radius: window half-size for non-maximum suppression; <= 0 skips it.
        scores: optional (N,) responses; when given they replace scales as the
            ranking value for both deduplication and NMS.

    Returns:
        Integer index array into the original points of the keypoints to keep.
    """
    h, w = image_shape
    # Snap to integer pixels and transpose to (row, col) pairs for indexing.
    ij = np.round(points - 0.5).astype(int).T[::-1]
    # Remove duplicate points (identical coordinates).
    # Pick highest scale or score
    s = scales if scores is None else scores
    buffer = np.zeros((h, w))
    # Scatter-max each point's value into its pixel; survivors are the points
    # whose value equals the per-pixel maximum.
    np.maximum.at(buffer, tuple(ij), s)
    keep = np.where(buffer[tuple(ij)] == s)[0]
    # Pick lowest angle (arbitrary).
    # Second pass breaks remaining ties (equal value at one pixel) by keeping
    # the smallest absolute orientation; buffer is reused as a min-accumulator.
    ij = ij[:, keep]
    buffer[:] = np.inf
    o_abs = np.abs(angles[keep])
    np.minimum.at(buffer, tuple(ij), o_abs)
    mask = buffer[tuple(ij)] == o_abs
    ij = ij[:, mask]
    keep = keep[mask]
    if nms_radius > 0:
        # Apply NMS on the remaining points
        buffer[:] = 0
        buffer[tuple(ij)] = s[keep]  # scores or scale
        # Max-pool the score map; a point survives if it equals the local max
        # inside its (2r+1)^2 window.
        local_max = torch.nn.functional.max_pool2d(
            torch.from_numpy(buffer).unsqueeze(0),
            kernel_size=nms_radius * 2 + 1,
            stride=1,
            padding=nms_radius,
        ).squeeze(0)
        is_local_max = buffer == local_max.numpy()
        keep = keep[is_local_max[tuple(ij)]]
    return keep
  42. def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
  43. x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
  44. x.clip_(min=eps).sqrt_()
  45. return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
  46. def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
  47. """
  48. Detect keypoints using OpenCV Detector.
  49. Optionally, perform description.
  50. Args:
  51. features: OpenCV based keypoints detector and descriptor
  52. image: Grayscale image of uint8 data type
  53. Returns:
  54. keypoints: 1D array of detected cv2.KeyPoint
  55. scores: 1D array of responses
  56. descriptors: 1D array of descriptors
  57. """
  58. detections, descriptors = features.detectAndCompute(image, None)
  59. points = np.array([k.pt for k in detections], dtype=np.float32)
  60. scores = np.array([k.response for k in detections], dtype=np.float32)
  61. scales = np.array([k.size for k in detections], dtype=np.float32)
  62. angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
  63. return points, scores, scales, angles, descriptors
  64. class SIFT(Extractor):
  65. default_conf = {
  66. "rootsift": True,
  67. "nms_radius": 0, # None to disable filtering entirely.
  68. "max_num_keypoints": 4096,
  69. "backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
  70. "detection_threshold": 0.0066667, # from COLMAP
  71. "edge_threshold": 10,
  72. "first_octave": -1, # only used by pycolmap, the default of COLMAP
  73. "num_octaves": 4,
  74. }
  75. preprocess_conf = {
  76. "resize": 1024,
  77. }
  78. required_data_keys = ["image"]
  79. def __init__(self, **conf):
  80. super().__init__(**conf) # Update with default configuration.
  81. backend = self.conf.backend
  82. if backend.startswith("pycolmap"):
  83. if pycolmap is None:
  84. raise ImportError(
  85. "Cannot find module pycolmap: install it with pip"
  86. "or use backend=opencv."
  87. )
  88. options = {
  89. "peak_threshold": self.conf.detection_threshold,
  90. "edge_threshold": self.conf.edge_threshold,
  91. "first_octave": self.conf.first_octave,
  92. "num_octaves": self.conf.num_octaves,
  93. "normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
  94. }
  95. device = (
  96. "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
  97. )
  98. if (
  99. backend == "pycolmap_cpu" or not pycolmap.has_cuda
  100. ) and pycolmap.__version__ < "0.5.0":
  101. warnings.warn(
  102. "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
  103. "consider upgrading pycolmap or use the CUDA version.",
  104. stacklevel=1,
  105. )
  106. else:
  107. options["max_num_features"] = self.conf.max_num_keypoints
  108. self.sift = pycolmap.Sift(options=options, device=device)
  109. elif backend == "opencv":
  110. self.sift = cv2.SIFT_create(
  111. contrastThreshold=self.conf.detection_threshold,
  112. nfeatures=self.conf.max_num_keypoints,
  113. edgeThreshold=self.conf.edge_threshold,
  114. nOctaveLayers=self.conf.num_octaves,
  115. )
  116. else:
  117. backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
  118. raise ValueError(
  119. f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
  120. )
  121. def extract_single_image(self, image: torch.Tensor):
  122. image_np = image.cpu().numpy().squeeze(0)
  123. if self.conf.backend.startswith("pycolmap"):
  124. if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
  125. detections, descriptors = self.sift.extract(image_np)
  126. scores = None # Scores are not exposed by COLMAP anymore.
  127. else:
  128. detections, scores, descriptors = self.sift.extract(image_np)
  129. keypoints = detections[:, :2] # Keep only (x, y).
  130. scales, angles = detections[:, -2:].T
  131. if scores is not None and (
  132. self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
  133. ):
  134. # Set the scores as a combination of abs. response and scale.
  135. scores = np.abs(scores) * scales
  136. elif self.conf.backend == "opencv":
  137. # TODO: Check if opencv keypoints are already in corner convention
  138. keypoints, scores, scales, angles, descriptors = run_opencv_sift(
  139. self.sift, (image_np * 255.0).astype(np.uint8)
  140. )
  141. pred = {
  142. "keypoints": keypoints,
  143. "scales": scales,
  144. "oris": angles,
  145. "descriptors": descriptors,
  146. }
  147. if scores is not None:
  148. pred["keypoint_scores"] = scores
  149. # sometimes pycolmap returns points outside the image. We remove them
  150. if self.conf.backend.startswith("pycolmap"):
  151. is_inside = (
  152. pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
  153. ).all(-1)
  154. pred = {k: v[is_inside] for k, v in pred.items()}
  155. if self.conf.nms_radius is not None:
  156. keep = filter_dog_point(
  157. pred["keypoints"],
  158. pred["scales"],
  159. pred["oris"],
  160. image_np.shape,
  161. self.conf.nms_radius,
  162. scores=pred.get("keypoint_scores"),
  163. )
  164. pred = {k: v[keep] for k, v in pred.items()}
  165. pred = {k: torch.from_numpy(v) for k, v in pred.items()}
  166. if scores is not None:
  167. # Keep the k keypoints with highest score
  168. num_points = self.conf.max_num_keypoints
  169. if num_points is not None and len(pred["keypoints"]) > num_points:
  170. indices = torch.topk(pred["keypoint_scores"], num_points).indices
  171. pred = {k: v[indices] for k, v in pred.items()}
  172. return pred
  173. def forward(self, data: dict) -> dict:
  174. image = data["image"]
  175. if image.shape[1] == 3:
  176. image = rgb_to_grayscale(image)
  177. device = image.device
  178. image = image.cpu()
  179. pred = []
  180. for k in range(len(image)):
  181. img = image[k]
  182. if "image_size" in data.keys():
  183. # avoid extracting points in padded areas
  184. w, h = data["image_size"][k]
  185. img = img[:, :h, :w]
  186. p = self.extract_single_image(img)
  187. pred.append(p)
  188. pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
  189. if self.conf.rootsift:
  190. pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
  191. return pred