superpoint.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # %BANNER_BEGIN%
  2. # ---------------------------------------------------------------------
  3. # %COPYRIGHT_BEGIN%
  4. #
  5. # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
  6. #
  7. # Unpublished Copyright (c) 2020
  8. # Magic Leap, Inc., All Rights Reserved.
  9. #
  10. # NOTICE: All information contained herein is, and remains the property
  11. # of COMPANY. The intellectual and technical concepts contained herein
  12. # are proprietary to COMPANY and may be covered by U.S. and Foreign
  13. # Patents, patents in process, and are protected by trade secret or
  14. # copyright law. Dissemination of this information or reproduction of
  15. # this material is strictly forbidden unless prior written permission is
  16. # obtained from COMPANY. Access to the source code contained herein is
  17. # hereby forbidden to anyone except current COMPANY employees, managers
  18. # or contractors who have executed Confidentiality and Non-disclosure
  19. # agreements explicitly covering such access.
  20. #
  21. # The copyright notice above does not evidence any actual or intended
  22. # publication or disclosure of this source code, which includes
  23. # information that is confidential and/or proprietary, and is a trade
  24. # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
  25. # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
  26. # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
  27. # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
  28. # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
  29. # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
  30. # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
  31. # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
  32. #
  33. # %COPYRIGHT_END%
  34. # ----------------------------------------------------------------------
  35. # %AUTHORS_BEGIN%
  36. #
  37. # Originating Authors: Paul-Edouard Sarlin
  38. #
  39. # %AUTHORS_END%
  40. # --------------------------------------------------------------------*/
  41. # %BANNER_END%
  42. # Adapted by Remi Pautrat, Philipp Lindenberger
  43. import torch
  44. from kornia.color import rgb_to_grayscale
  45. from torch import nn
  46. from .utils import Extractor
  47. def simple_nms(scores, nms_radius: int):
  48. """Fast Non-maximum suppression to remove nearby points"""
  49. assert nms_radius >= 0
  50. def max_pool(x):
  51. return torch.nn.functional.max_pool2d(
  52. x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
  53. )
  54. zeros = torch.zeros_like(scores)
  55. max_mask = scores == max_pool(scores)
  56. for _ in range(2):
  57. supp_mask = max_pool(max_mask.float()) > 0
  58. supp_scores = torch.where(supp_mask, zeros, scores)
  59. new_max_mask = supp_scores == max_pool(supp_scores)
  60. max_mask = max_mask | (new_max_mask & (~supp_mask))
  61. return torch.where(max_mask, scores, zeros)
  62. def top_k_keypoints(keypoints, scores, k):
  63. if k >= len(keypoints):
  64. return keypoints, scores
  65. scores, indices = torch.topk(scores, k, dim=0, sorted=True)
  66. return keypoints[indices], scores
  67. def sample_descriptors(keypoints, descriptors, s: int = 8):
  68. """Interpolate descriptors at keypoint locations"""
  69. b, c, h, w = descriptors.shape
  70. keypoints = keypoints - s / 2 + 0.5
  71. keypoints /= torch.tensor(
  72. [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
  73. ).to(
  74. keypoints
  75. )[None]
  76. keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
  77. args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
  78. descriptors = torch.nn.functional.grid_sample(
  79. descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
  80. )
  81. descriptors = torch.nn.functional.normalize(
  82. descriptors.reshape(b, c, -1), p=2, dim=1
  83. )
  84. return descriptors
  85. class SuperPoint(Extractor):
  86. """SuperPoint Convolutional Detector and Descriptor
  87. SuperPoint: Self-Supervised Interest Point Detection and
  88. Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
  89. Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
  90. """
  91. default_conf = {
  92. "descriptor_dim": 256,
  93. "nms_radius": 4,
  94. "max_num_keypoints": None,
  95. "detection_threshold": 0.0005,
  96. "remove_borders": 4,
  97. }
  98. preprocess_conf = {
  99. "resize": 1024,
  100. }
  101. required_data_keys = ["image"]
  102. def __init__(self, **conf):
  103. super().__init__(**conf) # Update with default configuration.
  104. self.relu = nn.ReLU(inplace=True)
  105. self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  106. c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
  107. self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
  108. self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
  109. self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
  110. self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
  111. self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
  112. self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
  113. self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
  114. self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
  115. self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
  116. self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
  117. self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
  118. self.convDb = nn.Conv2d(
  119. c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
  120. )
  121. url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth" # noqa
  122. self.load_state_dict(torch.hub.load_state_dict_from_url(url))
  123. if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
  124. raise ValueError("max_num_keypoints must be positive or None")
  125. def forward(self, data: dict) -> dict:
  126. """Compute keypoints, scores, descriptors for image"""
  127. for key in self.required_data_keys:
  128. assert key in data, f"Missing key {key} in data"
  129. image = data["image"]
  130. if image.shape[1] == 3:
  131. image = rgb_to_grayscale(image)
  132. # Shared Encoder
  133. x = self.relu(self.conv1a(image))
  134. x = self.relu(self.conv1b(x))
  135. x = self.pool(x)
  136. x = self.relu(self.conv2a(x))
  137. x = self.relu(self.conv2b(x))
  138. x = self.pool(x)
  139. x = self.relu(self.conv3a(x))
  140. x = self.relu(self.conv3b(x))
  141. x = self.pool(x)
  142. x = self.relu(self.conv4a(x))
  143. x = self.relu(self.conv4b(x))
  144. # Compute the dense keypoint scores
  145. cPa = self.relu(self.convPa(x))
  146. scores = self.convPb(cPa)
  147. scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
  148. b, _, h, w = scores.shape
  149. scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
  150. scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
  151. scores = simple_nms(scores, self.conf.nms_radius)
  152. # Discard keypoints near the image borders
  153. if self.conf.remove_borders:
  154. pad = self.conf.remove_borders
  155. scores[:, :pad] = -1
  156. scores[:, :, :pad] = -1
  157. scores[:, -pad:] = -1
  158. scores[:, :, -pad:] = -1
  159. # Extract keypoints
  160. best_kp = torch.where(scores > self.conf.detection_threshold)
  161. scores = scores[best_kp]
  162. # Separate into batches
  163. keypoints = [
  164. torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
  165. ]
  166. scores = [scores[best_kp[0] == i] for i in range(b)]
  167. # Keep the k keypoints with highest score
  168. if self.conf.max_num_keypoints is not None:
  169. keypoints, scores = list(
  170. zip(
  171. *[
  172. top_k_keypoints(k, s, self.conf.max_num_keypoints)
  173. for k, s in zip(keypoints, scores)
  174. ]
  175. )
  176. )
  177. # Convert (h, w) to (x, y)
  178. keypoints = [torch.flip(k, [1]).float() for k in keypoints]
  179. # Compute the dense descriptors
  180. cDa = self.relu(self.convDa(x))
  181. descriptors = self.convDb(cDa)
  182. descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
  183. # Extract descriptors
  184. descriptors = [
  185. sample_descriptors(k[None], d[None], 8)[0]
  186. for k, d in zip(keypoints, descriptors)
  187. ]
  188. return {
  189. "keypoints": torch.stack(keypoints, 0),
  190. "keypoint_scores": torch.stack(scores, 0),
  191. "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
  192. }