modeling_superpoint.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SuperPoint model."""
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import nn
  18. from transformers import PreTrainedModel
  19. from transformers.modeling_outputs import (
  20. BaseModelOutputWithNoAttention,
  21. )
  22. from transformers.models.superpoint.configuration_superpoint import SuperPointConfig
  23. from ...utils import (
  24. ModelOutput,
  25. auto_docstring,
  26. logging,
  27. )
# Module-level logger, created through the library's `logging` helper and named after this module.
logger = logging.get_logger(__name__)
  29. def remove_keypoints_from_borders(
  30. keypoints: torch.Tensor, scores: torch.Tensor, border: int, height: int, width: int
  31. ) -> tuple[torch.Tensor, torch.Tensor]:
  32. """Removes keypoints (and their associated scores) that are too close to the border"""
  33. mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
  34. mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
  35. mask = mask_h & mask_w
  36. return keypoints[mask], scores[mask]
  37. def top_k_keypoints(keypoints: torch.Tensor, scores: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
  38. """Keeps the k keypoints with highest score"""
  39. if k >= len(keypoints):
  40. return keypoints, scores
  41. scores, indices = torch.topk(scores, k, dim=0)
  42. return keypoints[indices], scores
  43. def simple_nms(scores: torch.Tensor, nms_radius: int) -> torch.Tensor:
  44. """Applies non-maximum suppression on scores"""
  45. if nms_radius < 0:
  46. raise ValueError("Expected positive values for nms_radius")
  47. def max_pool(x):
  48. return nn.functional.max_pool2d(x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
  49. zeros = torch.zeros_like(scores)
  50. max_mask = scores == max_pool(scores)
  51. for _ in range(2):
  52. supp_mask = max_pool(max_mask.float()) > 0
  53. supp_scores = torch.where(supp_mask, zeros, scores)
  54. new_max_mask = supp_scores == max_pool(supp_scores)
  55. max_mask = max_mask | (new_max_mask & (~supp_mask))
  56. return torch.where(max_mask, scores, zeros)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of image point description models. Due to the nature of keypoint detection, the number of
    keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of images,
    the maximum number of keypoints is set as the dimension of the keypoints, scores and descriptors tensors. The mask
    tensor is used to indicate which values in the keypoints, scores and descriptors tensors are keypoint information
    and which are padding.
    """
)
class SuperPointKeypointDescriptionOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Relative (x, y) coordinates of predicted keypoints in a given image, i.e. pixel coordinates divided by the
        image width and height. Rows beyond the number of keypoints found for an image are zero padding.
    scores (`torch.FloatTensor` of shape `(batch_size, num_keypoints)`):
        Scores of predicted keypoints. Padded positions hold zeros.
    descriptors (`torch.FloatTensor` of shape `(batch_size, num_keypoints, descriptor_size)`):
        Descriptors of predicted keypoints. Padded positions hold zeros.
    mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
        Mask indicating which values in keypoints, scores and descriptors are keypoint information (1) and which
        are padding (0).
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
    when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor`, one for the output of each encoder stage, each of shape
        `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at the
        output of each stage.
    """

    loss: torch.FloatTensor | None = None
    # Keypoints are stored as floating point (relative) coordinates, not integer pixel indices.
    keypoints: torch.FloatTensor | None = None
    scores: torch.FloatTensor | None = None
    descriptors: torch.FloatTensor | None = None
    # The model fills this with integer 0/1 values rather than booleans.
    mask: torch.IntTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
  91. class SuperPointConvBlock(nn.Module):
  92. def __init__(
  93. self, config: SuperPointConfig, in_channels: int, out_channels: int, add_pooling: bool = False
  94. ) -> None:
  95. super().__init__()
  96. self.conv_a = nn.Conv2d(
  97. in_channels,
  98. out_channels,
  99. kernel_size=3,
  100. stride=1,
  101. padding=1,
  102. )
  103. self.conv_b = nn.Conv2d(
  104. out_channels,
  105. out_channels,
  106. kernel_size=3,
  107. stride=1,
  108. padding=1,
  109. )
  110. self.relu = nn.ReLU(inplace=True)
  111. self.pool = nn.MaxPool2d(kernel_size=2, stride=2) if add_pooling else None
  112. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  113. hidden_states = self.relu(self.conv_a(hidden_states))
  114. hidden_states = self.relu(self.conv_b(hidden_states))
  115. if self.pool is not None:
  116. hidden_states = self.pool(hidden_states)
  117. return hidden_states
  118. class SuperPointEncoder(nn.Module):
  119. """
  120. SuperPoint encoder module. It is made of 4 convolutional layers with ReLU activation and max pooling, reducing the
  121. dimensionality of the image.
  122. """
  123. def __init__(self, config: SuperPointConfig) -> None:
  124. super().__init__()
  125. # SuperPoint uses 1 channel images
  126. self.input_dim = 1
  127. conv_blocks = []
  128. conv_blocks.append(
  129. SuperPointConvBlock(config, self.input_dim, config.encoder_hidden_sizes[0], add_pooling=True)
  130. )
  131. for i in range(1, len(config.encoder_hidden_sizes) - 1):
  132. conv_blocks.append(
  133. SuperPointConvBlock(
  134. config, config.encoder_hidden_sizes[i - 1], config.encoder_hidden_sizes[i], add_pooling=True
  135. )
  136. )
  137. conv_blocks.append(
  138. SuperPointConvBlock(
  139. config, config.encoder_hidden_sizes[-2], config.encoder_hidden_sizes[-1], add_pooling=False
  140. )
  141. )
  142. self.conv_blocks = nn.ModuleList(conv_blocks)
  143. def forward(
  144. self,
  145. input,
  146. output_hidden_states: bool | None = False,
  147. return_dict: bool | None = True,
  148. ) -> tuple | BaseModelOutputWithNoAttention:
  149. all_hidden_states = () if output_hidden_states else None
  150. for conv_block in self.conv_blocks:
  151. input = conv_block(input)
  152. if output_hidden_states:
  153. all_hidden_states = all_hidden_states + (input,)
  154. output = input
  155. if not return_dict:
  156. return tuple(v for v in [output, all_hidden_states] if v is not None)
  157. return BaseModelOutputWithNoAttention(
  158. last_hidden_state=output,
  159. hidden_states=all_hidden_states,
  160. )
  161. class SuperPointInterestPointDecoder(nn.Module):
  162. """
  163. The SuperPointInterestPointDecoder uses the output of the SuperPointEncoder to compute the keypoint with scores.
  164. The scores are first computed by a convolutional layer, then a softmax is applied to get a probability distribution
  165. over the 65 possible keypoint classes. The keypoints are then extracted from the scores by thresholding and
  166. non-maximum suppression. Post-processing is then applied to remove keypoints too close to the image borders as well
  167. as to keep only the k keypoints with highest score.
  168. """
  169. def __init__(self, config: SuperPointConfig) -> None:
  170. super().__init__()
  171. self.keypoint_threshold = config.keypoint_threshold
  172. self.max_keypoints = config.max_keypoints
  173. self.nms_radius = config.nms_radius
  174. self.border_removal_distance = config.border_removal_distance
  175. self.relu = nn.ReLU(inplace=True)
  176. self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  177. self.conv_score_a = nn.Conv2d(
  178. config.encoder_hidden_sizes[-1],
  179. config.decoder_hidden_size,
  180. kernel_size=3,
  181. stride=1,
  182. padding=1,
  183. )
  184. self.conv_score_b = nn.Conv2d(
  185. config.decoder_hidden_size, config.keypoint_decoder_dim, kernel_size=1, stride=1, padding=0
  186. )
  187. def forward(self, encoded: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  188. scores = self._get_pixel_scores(encoded)
  189. keypoints, scores = self._extract_keypoints(scores)
  190. return keypoints, scores
  191. def _get_pixel_scores(self, encoded: torch.Tensor) -> torch.Tensor:
  192. """Based on the encoder output, compute the scores for each pixel of the image"""
  193. scores = self.relu(self.conv_score_a(encoded))
  194. scores = self.conv_score_b(scores)
  195. scores = nn.functional.softmax(scores, 1)[:, :-1]
  196. batch_size, _, height, width = scores.shape
  197. scores = scores.permute(0, 2, 3, 1).reshape(batch_size, height, width, 8, 8)
  198. scores = scores.permute(0, 1, 3, 2, 4).reshape(batch_size, height * 8, width * 8)
  199. scores = simple_nms(scores, self.nms_radius)
  200. return scores
  201. def _extract_keypoints(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  202. """
  203. Based on their scores, extract the pixels that represent the keypoints that will be used for descriptors computation.
  204. The keypoints are in the form of relative (x, y) coordinates.
  205. """
  206. _, height, width = scores.shape
  207. # Threshold keypoints by score value
  208. keypoints = torch.nonzero(scores[0] > self.keypoint_threshold)
  209. scores = scores[0][tuple(keypoints.t())]
  210. # Discard keypoints near the image borders
  211. keypoints, scores = remove_keypoints_from_borders(
  212. keypoints, scores, self.border_removal_distance, height * 8, width * 8
  213. )
  214. # Keep the k keypoints with highest score
  215. if self.max_keypoints >= 0:
  216. keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints)
  217. # Convert (y, x) to (x, y)
  218. keypoints = torch.flip(keypoints, [1]).to(scores.dtype)
  219. return keypoints, scores
  220. class SuperPointDescriptorDecoder(nn.Module):
  221. """
  222. The SuperPointDescriptorDecoder uses the outputs of both the SuperPointEncoder and the
  223. SuperPointInterestPointDecoder to compute the descriptors at the keypoints locations.
  224. The descriptors are first computed by a convolutional layer, then normalized to have a norm of 1. The descriptors
  225. are then interpolated at the keypoints locations.
  226. """
  227. def __init__(self, config: SuperPointConfig) -> None:
  228. super().__init__()
  229. self.relu = nn.ReLU(inplace=True)
  230. self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  231. self.conv_descriptor_a = nn.Conv2d(
  232. config.encoder_hidden_sizes[-1],
  233. config.decoder_hidden_size,
  234. kernel_size=3,
  235. stride=1,
  236. padding=1,
  237. )
  238. self.conv_descriptor_b = nn.Conv2d(
  239. config.decoder_hidden_size,
  240. config.descriptor_decoder_dim,
  241. kernel_size=1,
  242. stride=1,
  243. padding=0,
  244. )
  245. def forward(self, encoded: torch.Tensor, keypoints: torch.Tensor) -> torch.Tensor:
  246. """Based on the encoder output and the keypoints, compute the descriptors for each keypoint"""
  247. descriptors = self.conv_descriptor_b(self.relu(self.conv_descriptor_a(encoded)))
  248. descriptors = nn.functional.normalize(descriptors, p=2, dim=1)
  249. descriptors = self._sample_descriptors(keypoints[None], descriptors[0][None], 8)[0]
  250. # [descriptor_dim, num_keypoints] -> [num_keypoints, descriptor_dim]
  251. descriptors = torch.transpose(descriptors, 0, 1)
  252. return descriptors
  253. @staticmethod
  254. def _sample_descriptors(keypoints, descriptors, scale: int = 8) -> torch.Tensor:
  255. """Interpolate descriptors at keypoint locations"""
  256. batch_size, num_channels, height, width = descriptors.shape
  257. keypoints = keypoints - scale / 2 + 0.5
  258. divisor = torch.tensor([[(width * scale - scale / 2 - 0.5), (height * scale - scale / 2 - 0.5)]])
  259. divisor = divisor.to(keypoints)
  260. keypoints /= divisor
  261. keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
  262. kwargs = {"align_corners": True}
  263. # [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2]
  264. keypoints = keypoints.view(batch_size, 1, -1, 2)
  265. descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs)
  266. # [batch_size, descriptor_decoder_dim, num_channels, num_keypoints] -> [batch_size, descriptor_decoder_dim, num_keypoints]
  267. descriptors = descriptors.reshape(batch_size, num_channels, -1)
  268. descriptors = nn.functional.normalize(descriptors, p=2, dim=1)
  269. return descriptors
  270. @auto_docstring
  271. class SuperPointPreTrainedModel(PreTrainedModel):
  272. config: SuperPointConfig
  273. base_model_prefix = "superpoint"
  274. main_input_name = "pixel_values"
  275. input_modalities = ("image",)
  276. supports_gradient_checkpointing = False
  277. def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
  278. """
  279. Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same,
  280. extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for SuperPoint. This is
  281. a workaround for the issue discussed in :
  282. https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
  283. Args:
  284. pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)
  285. Returns:
  286. pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
  287. """
  288. return pixel_values[:, 0, :, :][:, None, :, :]
  289. @auto_docstring(
  290. custom_intro="""
  291. SuperPoint model outputting keypoints and descriptors.
  292. """
  293. )
  294. class SuperPointForKeypointDetection(SuperPointPreTrainedModel):
  295. """
  296. SuperPoint model. It consists of a SuperPointEncoder, a SuperPointInterestPointDecoder and a
  297. SuperPointDescriptorDecoder. SuperPoint was proposed in `SuperPoint: Self-Supervised Interest Point Detection and
  298. Description <https://huggingface.co/papers/1712.07629>`__ by Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich. It
  299. is a fully convolutional neural network that extracts keypoints and descriptors from an image. It is trained in a
  300. self-supervised manner, using a combination of a photometric loss and a loss based on the homographic adaptation of
  301. keypoints. It is made of a convolutional encoder and two decoders: one for keypoints and one for descriptors.
  302. """
  303. def __init__(self, config: SuperPointConfig) -> None:
  304. super().__init__(config)
  305. self.config = config
  306. self.encoder = SuperPointEncoder(config)
  307. self.keypoint_decoder = SuperPointInterestPointDecoder(config)
  308. self.descriptor_decoder = SuperPointDescriptorDecoder(config)
  309. self.post_init()
  310. @auto_docstring
  311. def forward(
  312. self,
  313. pixel_values: torch.FloatTensor,
  314. labels: torch.LongTensor | None = None,
  315. output_hidden_states: bool | None = None,
  316. return_dict: bool | None = None,
  317. **kwargs,
  318. ) -> tuple | SuperPointKeypointDescriptionOutput:
  319. r"""
  320. Examples:
  321. ```python
  322. >>> from transformers import AutoImageProcessor, SuperPointForKeypointDetection
  323. >>> import torch
  324. >>> from PIL import Image
  325. >>> import httpx
  326. >>> from io import BytesIO
  327. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  328. >>> with httpx.stream("GET", url) as response:
  329. ... image = Image.open(BytesIO(response.read()))
  330. >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
  331. >>> model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
  332. >>> inputs = processor(image, return_tensors="pt")
  333. >>> outputs = model(**inputs)
  334. ```"""
  335. loss = None
  336. if labels is not None:
  337. raise ValueError("SuperPoint does not support training for now.")
  338. output_hidden_states = (
  339. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  340. )
  341. return_dict = return_dict if return_dict is not None else self.config.return_dict
  342. pixel_values = self.extract_one_channel_pixel_values(pixel_values)
  343. batch_size, _, height, width = pixel_values.shape
  344. encoder_outputs = self.encoder(
  345. pixel_values,
  346. output_hidden_states=output_hidden_states,
  347. return_dict=return_dict,
  348. )
  349. last_hidden_state = encoder_outputs[0]
  350. list_keypoints_scores = [
  351. self.keypoint_decoder(last_hidden_state[None, ...]) for last_hidden_state in last_hidden_state
  352. ]
  353. list_keypoints = [keypoints_scores[0] for keypoints_scores in list_keypoints_scores]
  354. list_scores = [keypoints_scores[1] for keypoints_scores in list_keypoints_scores]
  355. list_descriptors = [
  356. self.descriptor_decoder(last_hidden_state[None, ...], keypoints[None, ...])
  357. for last_hidden_state, keypoints in zip(last_hidden_state, list_keypoints)
  358. ]
  359. maximum_num_keypoints = max(keypoints.shape[0] for keypoints in list_keypoints)
  360. keypoints = torch.zeros((batch_size, maximum_num_keypoints, 2), device=pixel_values.device)
  361. scores = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device)
  362. descriptors = torch.zeros(
  363. (batch_size, maximum_num_keypoints, self.config.descriptor_decoder_dim),
  364. device=pixel_values.device,
  365. )
  366. mask = torch.zeros((batch_size, maximum_num_keypoints), device=pixel_values.device, dtype=torch.int)
  367. for i, (_keypoints, _scores, _descriptors) in enumerate(zip(list_keypoints, list_scores, list_descriptors)):
  368. keypoints[i, : _keypoints.shape[0]] = _keypoints
  369. scores[i, : _scores.shape[0]] = _scores
  370. descriptors[i, : _descriptors.shape[0]] = _descriptors
  371. mask[i, : _scores.shape[0]] = 1
  372. # Convert to relative coordinates
  373. keypoints = keypoints / torch.tensor([width, height], device=keypoints.device)
  374. hidden_states = encoder_outputs[1] if output_hidden_states else None
  375. if not return_dict:
  376. return tuple(v for v in [loss, keypoints, scores, descriptors, mask, hidden_states] if v is not None)
  377. return SuperPointKeypointDescriptionOutput(
  378. loss=loss,
  379. keypoints=keypoints,
  380. scores=scores,
  381. descriptors=descriptors,
  382. mask=mask,
  383. hidden_states=hidden_states,
  384. )
# Names exported by a star-import of this module; every other definition in the file is left out of the list.
__all__ = ["SuperPointForKeypointDetection", "SuperPointPreTrainedModel"]