modeling_superglue.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SuperGlue model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from transformers import PreTrainedModel
  20. from transformers.models.superglue.configuration_superglue import SuperGlueConfig
  21. from ... import initialization as init
  22. from ...utils import ModelOutput, auto_docstring, logging
  23. from ..auto import AutoModelForKeypointDetection
  # Module-level logger, named after this module.
  logger = logging.get_logger(__name__)
  def concat_pairs(tensor_tuple0: tuple[torch.Tensor], tensor_tuple1: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
      """
      Concatenate two tuples of tensors pairwise.

      The i-th output is `torch.cat([tensor_tuple0[i], tensor_tuple1[i]])`, i.e. a concatenation along the first
      dimension. Iteration stops at the shorter of the two tuples.

      Args:
          tensor_tuple0 (`tuple[torch.Tensor]`):
              Tuple of tensors.
          tensor_tuple1 (`tuple[torch.Tensor]`):
              Tuple of tensors.

      Returns:
          (`tuple[torch.Tensor]`): Tuple of concatenated tensors.
      """
      return tuple(torch.cat([tensor0, tensor1]) for tensor0, tensor1 in zip(tensor_tuple0, tensor_tuple1))
  def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
      """
      Normalize keypoint locations based on the image shape.

      Keypoints are shifted so that the image center becomes the origin, then divided by
      `0.7 * max(width, height)`.

      Args:
          keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
              Keypoints locations in (x, y) format.
          height (`int`):
              Image height.
          width (`int`):
              Image width.

      Returns:
          Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
      """
      # (1, 2) tensor holding (width, height), matching the (x, y) order of the keypoints.
      size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
      center = size / 2
      # A single isotropic scale derived from the longest image side.
      scaling = size.max(1, keepdim=True).values * 0.7
      return (keypoints - center[:, None, :]) / scaling[:, None, :]
  def log_sinkhorn_iterations(
      log_cost_matrix: torch.Tensor,
      log_source_distribution: torch.Tensor,
      log_target_distribution: torch.Tensor,
      num_iterations: int,
  ) -> torch.Tensor:
      """
      Perform Sinkhorn Normalization in Log-space for stability

      Args:
          log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
              Logarithm of the cost matrix.
          log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`):
              Logarithm of the source distribution.
          log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`):
              Logarithm of the target distribution.
          num_iterations (`int`):
              Number of alternating row/column scaling updates. With `0` the input matrix is returned unscaled.

      Returns:
          log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal
          transport matrix.
      """
      log_u_scaling = torch.zeros_like(log_source_distribution)
      log_v_scaling = torch.zeros_like(log_target_distribution)
      for _ in range(num_iterations):
          # Row update: make each row's log-sum match the source marginal given the current column scaling.
          log_u_scaling = log_source_distribution - torch.logsumexp(log_cost_matrix + log_v_scaling.unsqueeze(1), dim=2)
          # Column update: make each column's log-sum match the target marginal given the new row scaling.
          log_v_scaling = log_target_distribution - torch.logsumexp(log_cost_matrix + log_u_scaling.unsqueeze(2), dim=1)
      return log_cost_matrix + log_u_scaling.unsqueeze(2) + log_v_scaling.unsqueeze(1)
  def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor:
      """
      Perform Differentiable Optimal Transport in Log-space for stability

      The score matrix is augmented with one extra row and one extra column filled with `reg_param` before running
      the Sinkhorn iterations, so the output has one more row and one more column than `scores`.

      Args:
          scores: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
              Cost matrix.
          reg_param: (`torch.Tensor`):
              Regularization parameter. It is expanded to `(batch_size, num_rows, 1)`, `(batch_size, 1, num_columns)`
              and `(batch_size, 1, 1)`, so it must be broadcastable to those shapes (e.g. a scalar tensor).
          iterations: (`int`):
              Number of Sinkhorn iterations.

      Returns:
          log_optimal_transport_matrix: (`torch.Tensor` of shape `(batch_size, num_rows + 1, num_columns + 1)`):
          Logarithm of the optimal transport matrix.
      """
      batch_size, num_rows, num_columns = scores.shape
      one_tensor = scores.new_tensor(1)
      # Row/column counts as tensors with the dtype and device of `scores`, so `.log()` can be taken on them.
      num_rows_tensor, num_columns_tensor = (num_rows * one_tensor).to(scores), (num_columns * one_tensor).to(scores)

      source_reg_param = reg_param.expand(batch_size, num_rows, 1)
      target_reg_param = reg_param.expand(batch_size, 1, num_columns)
      reg_param = reg_param.expand(batch_size, 1, 1)

      # Append the extra column (right) and the extra row (bottom) to the score matrix.
      couplings = torch.cat([torch.cat([scores, source_reg_param], -1), torch.cat([target_reg_param, reg_param], -1)], 1)

      log_normalization = -(num_rows_tensor + num_columns_tensor).log()
      # Each regular row/column gets mass 1 / (M + N); the extra row gets M / (M + N) and the extra column
      # N / (M + N), where M = num_rows and N = num_columns.
      log_source_distribution = torch.cat(
          [log_normalization.expand(num_rows), num_columns_tensor.log()[None] + log_normalization]
      )
      log_target_distribution = torch.cat(
          [log_normalization.expand(num_columns), num_rows_tensor.log()[None] + log_normalization]
      )
      log_source_distribution, log_target_distribution = (
          log_source_distribution[None].expand(batch_size, -1),
          log_target_distribution[None].expand(batch_size, -1),
      )

      log_optimal_transport_matrix = log_sinkhorn_iterations(
          couplings, log_source_distribution, log_target_distribution, num_iterations=iterations
      )
      log_optimal_transport_matrix = log_optimal_transport_matrix - log_normalization  # multiply probabilities by M+N
      return log_optimal_transport_matrix
  def arange_like(x, dim: int) -> torch.Tensor:
      """
      Return the 1D tensor `[0, 1, ..., x.shape[dim] - 1]`.

      The tensor is built from `x.new_ones(...)`, so it is created on the same device as `x`.

      Args:
          x (`torch.Tensor`):
              Reference tensor providing the length (along `dim`) and the device.
          dim (`int`):
              Dimension of `x` whose size determines the length of the result.

      Returns:
          `torch.Tensor` of shape `(x.shape[dim],)` holding consecutive indices starting at 0.
      """
      return x.new_ones(x.shape[dim]).cumsum(0) - 1
  @dataclass
  @auto_docstring(
      custom_intro="""
      Base class for outputs of SuperGlue keypoint matching models. Due to the nature of keypoint detection and matching, the number
      of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
      images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is
      used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching
      information.
      """
  )
  class SuperGlueKeypointMatchingOutput(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
          Loss computed during training.
      matches (`torch.Tensor` of integer dtype and shape `(batch_size, 2, num_matches)`):
          Index of keypoint matched in the other image, or `-1` when the keypoint has no match.
      matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
          Scores of predicted matches.
      keypoints (`torch.FloatTensor` of shape `(batch_size, 2, num_keypoints, 2)`):
          (x, y) coordinates of the keypoints of each image of the pair, as returned by the keypoint detector.
      mask (`torch.IntTensor` of shape `(batch_size, 2, num_keypoints)`):
          Mask indicating which values in matches and matching_scores are keypoint matching information.
      hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
          Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
          num_keypoints)`, returned when `output_hidden_states=True` is passed or when
          `config.output_hidden_states=True`)
      attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
          num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`)
      """

      loss: torch.FloatTensor | None = None
      matches: torch.FloatTensor | None = None
      matching_scores: torch.FloatTensor | None = None
      keypoints: torch.FloatTensor | None = None
      mask: torch.IntTensor | None = None
      hidden_states: tuple[torch.FloatTensor] | None = None
      attentions: tuple[torch.FloatTensor] | None = None
  class SuperGlueMultiLayerPerceptron(nn.Module):
      """Linear -> BatchNorm1d -> ReLU block applied on the last (channel) dimension of the input."""

      def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None:
          """
          Args:
              config (`SuperGlueConfig`):
                  Model configuration. Not read by this block; kept for a uniform constructor signature.
              in_channels (`int`):
                  Size of the last dimension of the input.
              out_channels (`int`):
                  Size of the last dimension of the output.
          """
          super().__init__()
          self.linear = nn.Linear(in_channels, out_channels)
          self.batch_norm = nn.BatchNorm1d(out_channels)
          self.activation = nn.ReLU()

      def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
          """Apply the block; the last dimension goes from `in_channels` to `out_channels`."""
          hidden_state = self.linear(hidden_state)
          # BatchNorm1d normalizes over dimension 1, so move the channels there and back.
          hidden_state = hidden_state.transpose(-1, -2)
          hidden_state = self.batch_norm(hidden_state)
          hidden_state = hidden_state.transpose(-1, -2)
          hidden_state = self.activation(hidden_state)
          return hidden_state
  class SuperGlueKeypointEncoder(nn.Module):
      """Encode keypoint locations and detection scores into `config.hidden_size`-dimensional vectors."""

      def __init__(self, config: SuperGlueConfig) -> None:
          super().__init__()
          layer_sizes = config.keypoint_encoder_sizes
          hidden_size = config.hidden_size
          # 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint
          encoder_channels = [3] + layer_sizes + [hidden_size]

          # Every transition but the last is Linear + BatchNorm + ReLU; the last one is a plain Linear.
          layers = [
              SuperGlueMultiLayerPerceptron(config, encoder_channels[i - 1], encoder_channels[i])
              for i in range(1, len(encoder_channels) - 1)
          ]
          layers.append(nn.Linear(encoder_channels[-2], encoder_channels[-1]))
          self.encoder = nn.ModuleList(layers)

      def forward(
          self,
          keypoints: torch.Tensor,
          scores: torch.Tensor,
          output_hidden_states: bool | None = False,
      ) -> tuple[torch.Tensor, tuple[torch.Tensor] | None]:
          """
          Args:
              keypoints (`torch.Tensor`):
                  Keypoint coordinates; concatenated with `scores` on dimension 2, so the last dimension must be 2
                  for the result to match the 3 input channels of the encoder.
              scores (`torch.Tensor`):
                  Keypoint scores, with one dimension fewer than `keypoints`.
              output_hidden_states (`bool`, *optional*, defaults to `False`):
                  Whether to also return the output of every encoder layer.

          Returns:
              The encoded keypoints and, when `output_hidden_states` is truthy, a tuple with the output of each layer
              (otherwise `None`).
          """
          scores = scores.unsqueeze(2)
          hidden_state = torch.cat([keypoints, scores], dim=2)
          all_hidden_states = () if output_hidden_states else None
          for layer in self.encoder:
              hidden_state = layer(hidden_state)
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_state,)
          return hidden_state, all_hidden_states
  195. class SuperGlueSelfAttention(nn.Module):
  196. def __init__(self, config):
  197. super().__init__()
  198. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  199. raise ValueError(
  200. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  201. f"heads ({config.num_attention_heads})"
  202. )
  203. self.num_attention_heads = config.num_attention_heads
  204. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  205. self.all_head_size = self.num_attention_heads * self.attention_head_size
  206. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  207. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  208. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  209. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  210. self.is_decoder = config.is_decoder
  211. def forward(
  212. self,
  213. hidden_states: torch.Tensor,
  214. attention_mask: torch.FloatTensor | None = None,
  215. encoder_hidden_states: torch.FloatTensor | None = None,
  216. encoder_attention_mask: torch.FloatTensor | None = None,
  217. output_attentions: bool | None = False,
  218. ) -> tuple[torch.Tensor]:
  219. # If this is instantiated as a cross-attention module, the keys
  220. # and values come from an encoder; the attention mask needs to be
  221. # such that the encoder's padding tokens are not attended to.
  222. is_cross_attention = encoder_hidden_states is not None
  223. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  224. attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  225. batch_size = hidden_states.shape[0]
  226. key_layer = (
  227. self.key(current_states)
  228. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  229. .transpose(1, 2)
  230. )
  231. value_layer = (
  232. self.value(current_states)
  233. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  234. .transpose(1, 2)
  235. )
  236. query_layer = (
  237. self.query(hidden_states)
  238. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  239. .transpose(1, 2)
  240. )
  241. # Take the dot product between "query" and "key" to get the raw attention scores.
  242. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  243. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  244. if attention_mask is not None:
  245. # Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function)
  246. attention_scores = attention_scores + attention_mask
  247. # Normalize the attention scores to probabilities.
  248. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  249. # This is actually dropping out entire tokens to attend to, which might
  250. # seem a bit unusual, but is taken from the original Transformer paper.
  251. attention_probs = self.dropout(attention_probs)
  252. context_layer = torch.matmul(attention_probs, value_layer)
  253. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  254. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  255. context_layer = context_layer.view(new_context_layer_shape)
  256. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  257. if self.is_decoder:
  258. outputs = outputs + (None,)
  259. return outputs
  class SuperGlueSelfOutput(nn.Module):
      """Output projection applied to the attention context (a single `hidden_size -> hidden_size` linear layer)."""

      def __init__(self, config: SuperGlueConfig):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)

      def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor:
          """Project `hidden_states`; any extra positional arguments are accepted and ignored."""
          hidden_states = self.dense(hidden_states)
          return hidden_states
  # Maps `config._attn_implementation` to the attention class to instantiate.
  SUPERGLUE_SELF_ATTENTION_CLASSES = {
      "eager": SuperGlueSelfAttention,
  }
  class SuperGlueAttention(nn.Module):
      """Attention block: the attention module selected by `config._attn_implementation`, then an output projection."""

      def __init__(self, config):
          super().__init__()
          self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
          self.output = SuperGlueSelfOutput(config)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.FloatTensor | None = None,
          encoder_hidden_states: torch.FloatTensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          output_attentions: bool | None = False,
      ) -> tuple[torch.Tensor]:
          """
          Run attention and project its context.

          All arguments are forwarded unchanged to the inner attention module.

          Returns:
              A tuple whose first element is the projected attention output, followed by whatever extra elements the
              inner attention module returned (e.g. the attention probabilities).
          """
          self_outputs = self.self(
              hidden_states,
              attention_mask=attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              output_attentions=output_attentions,
          )
          attention_output = self.output(self_outputs[0], hidden_states)
          outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
          return outputs
  class SuperGlueAttentionalPropagation(nn.Module):
      """
      One message-passing step: attention produces a message for every keypoint, which is concatenated with the
      keypoint's descriptor and passed through an MLP.
      """

      def __init__(self, config: SuperGlueConfig) -> None:
          super().__init__()
          hidden_size = config.hidden_size
          self.attention = SuperGlueAttention(config)
          # The MLP input is the concatenation [descriptor, message], hence 2 * hidden_size.
          mlp_channels = [hidden_size * 2, hidden_size * 2, hidden_size]
          layers = [
              SuperGlueMultiLayerPerceptron(config, mlp_channels[i - 1], mlp_channels[i])
              for i in range(1, len(mlp_channels) - 1)
          ]
          layers.append(nn.Linear(mlp_channels[-2], mlp_channels[-1]))
          self.mlp = nn.ModuleList(layers)

      def forward(
          self,
          descriptors: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          output_attentions: bool = False,
          output_hidden_states: bool = False,
      ) -> tuple[torch.Tensor, tuple[torch.Tensor] | None, tuple[torch.Tensor] | None]:
          """
          Args:
              descriptors (`torch.Tensor`):
                  Current keypoint descriptors; concatenated with the attention output on dimension 2.
              attention_mask (`torch.Tensor`, *optional*):
                  Forwarded to the attention block.
              encoder_hidden_states (`torch.Tensor`, *optional*):
                  Forwarded to the attention block (source of keys/values for cross-attention).
              encoder_attention_mask (`torch.Tensor`, *optional*):
                  Forwarded to the attention block.
              output_attentions (`bool`, *optional*, defaults to `False`):
                  Whether the attention block should return its attention probabilities.
              output_hidden_states (`bool`, *optional*, defaults to `False`):
                  Whether to collect the output of every MLP layer.

          Returns:
              `(hidden_state, all_hidden_states, attention)` where `all_hidden_states` is a tuple with the output of
              each MLP layer (or `None`) and `attention` is the tuple of extra outputs of the attention block (empty
              when it returns none).
          """
          attention_outputs = self.attention(
              descriptors,
              attention_mask=attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              output_attentions=output_attentions,
          )
          output = attention_outputs[0]
          attention = attention_outputs[1:]

          hidden_state = torch.cat([descriptors, output], dim=2)

          all_hidden_states = () if output_hidden_states else None
          for layer in self.mlp:
              hidden_state = layer(hidden_state)
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_state,)

          return hidden_state, all_hidden_states, attention
  class SuperGlueAttentionalGNN(nn.Module):
      """
      Stack of attentional propagation layers, one per entry of `config.gnn_layers_types`.

      Layers whose type is `"cross"` attend to the other image of the pair; every other layer type attends within
      the same image. Each layer's output is added to the descriptors (residual update).
      """

      def __init__(self, config: SuperGlueConfig) -> None:
          super().__init__()
          self.hidden_size = config.hidden_size
          self.layers_types = config.gnn_layers_types
          self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in range(len(self.layers_types))])

      def forward(
          self,
          descriptors: torch.Tensor,
          mask: torch.Tensor | None = None,
          output_attentions: bool = False,
          output_hidden_states: bool | None = False,
      ) -> tuple[torch.Tensor, tuple | None, tuple | None]:
          """
          Args:
              descriptors (`torch.Tensor` of shape `(batch_size, num_keypoints, hidden_size)`):
                  Keypoint descriptors. The two images of a pair must occupy consecutive batch entries, since cross
                  layers regroup the batch as `(-1, 2, ...)` and flip the pair dimension.
              mask (`torch.Tensor`, *optional*):
                  Additive attention mask with `batch_size * num_keypoints` elements; for cross layers it is
                  regrouped and flipped like the descriptors and reshaped to `(batch_size, 1, 1, num_keypoints)`.
              output_attentions (`bool`, *optional*, defaults to `False`):
                  Whether to collect the attention outputs of every layer.
              output_hidden_states (`bool`, *optional*, defaults to `False`):
                  Whether to collect the input descriptors and the hidden states of every layer.

          Returns:
              `(descriptors, all_hidden_states, all_attentions)`; the last two are `None` unless requested.
          """
          all_hidden_states = () if output_hidden_states else None
          all_attentions = () if output_attentions else None

          batch_size, num_keypoints, _ = descriptors.shape
          if output_hidden_states:
              all_hidden_states = all_hidden_states + (descriptors,)

          for gnn_layer, layer_type in zip(self.layers, self.layers_types):
              encoder_hidden_states = None
              encoder_attention_mask = None
              if layer_type == "cross":
                  # Swap the two images of every pair so that each image attends to the other one.
                  encoder_hidden_states = (
                      descriptors.reshape(-1, 2, num_keypoints, self.hidden_size)
                      .flip(1)
                      .reshape(batch_size, num_keypoints, self.hidden_size)
                  )
                  encoder_attention_mask = (
                      mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
                      if mask is not None
                      else None
                  )

              gnn_outputs = gnn_layer(
                  descriptors,
                  attention_mask=mask,
                  encoder_hidden_states=encoder_hidden_states,
                  encoder_attention_mask=encoder_attention_mask,
                  output_hidden_states=output_hidden_states,
                  output_attentions=output_attentions,
              )
              delta = gnn_outputs[0]

              if output_hidden_states:
                  all_hidden_states = all_hidden_states + gnn_outputs[1]
              if output_attentions:
                  all_attentions = all_attentions + gnn_outputs[2]

              # Residual update of the descriptors.
              descriptors = descriptors + delta
          return descriptors, all_hidden_states, all_attentions
  class SuperGlueFinalProjection(nn.Module):
      """Final `hidden_size -> hidden_size` linear projection (with bias) of the keypoint descriptors."""

      def __init__(self, config: SuperGlueConfig) -> None:
          super().__init__()
          hidden_size = config.hidden_size
          self.final_proj = nn.Linear(hidden_size, hidden_size, bias=True)

      def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
          """Project the descriptors; the shape is unchanged."""
          return self.final_proj(descriptors)
  @auto_docstring
  class SuperGluePreTrainedModel(PreTrainedModel):
      # Class-level settings read by the `PreTrainedModel` machinery.
      config: SuperGlueConfig
      base_model_prefix = "superglue"
      main_input_name = "pixel_values"
      input_modalities = ("image",)

      @torch.no_grad()
      def _init_weights(self, module: nn.Module) -> None:
          """Initialize the weights"""
          super()._init_weights(module)
          # Modules exposing a `bin_score` attribute get it reset to ones on top of the default initialization.
          if hasattr(module, "bin_score"):
              init.ones_(module.bin_score)
  @auto_docstring(
      custom_intro="""
      SuperGlue model taking images as inputs and outputting the matching of them.
      """
  )
  class SuperGlueForKeypointMatching(SuperGluePreTrainedModel):
      """SuperGlue feature matching middle-end

      Given two sets of keypoints and locations, we determine the
      correspondences by:
        1. Keypoint Encoding (normalization + visual feature and location fusion)
        2. Graph Neural Network with multiple self and cross-attention layers
        3. Final projection layer
        4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
        5. Thresholding matrix based on mutual exclusivity and a match_threshold

      The correspondence ids use -1 to indicate non-matching points.

      Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
      Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
      Networks. In CVPR, 2020. https://huggingface.co/papers/1911.11763
      """

      def __init__(self, config: SuperGlueConfig) -> None:
          super().__init__(config)

          self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)

          self.keypoint_encoder = SuperGlueKeypointEncoder(config)
          self.gnn = SuperGlueAttentionalGNN(config)
          self.final_projection = SuperGlueFinalProjection(config)

          # Learnable score used to fill the extra row/column added by the optimal transport layer.
          bin_score = torch.nn.Parameter(torch.tensor(1.0))
          self.register_parameter("bin_score", bin_score)

          self.post_init()

      def _match_image_pair(
          self,
          keypoints: torch.Tensor,
          descriptors: torch.Tensor,
          scores: torch.Tensor,
          height: int,
          width: int,
          mask: torch.Tensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
      ) -> tuple[torch.Tensor, torch.Tensor, tuple, tuple]:
          """
          Perform keypoint matching between two images.

          Args:
              keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                  Keypoints detected in the pair of image.
              descriptors (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, descriptor_dim)`):
                  Descriptors of the keypoints detected in the image pair. `descriptor_dim` must equal
                  `config.hidden_size`.
              scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                  Confidence scores of the keypoints detected in the image pair.
              height (`int`): Image height.
              width (`int`): Image width.
              mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*):
                  Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint matching
                  information.
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors. Default to `config.output_attentions`.
              output_hidden_states (`bool`, *optional*):
                  Whether or not to return the hidden states of all layers. Default to `config.output_hidden_states`.

          Returns:
              matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                  For each image pair, for each keypoint in image0, the index of the keypoint in image1 that was matched
                  with. And for each keypoint in image1, the index of the keypoint in image0 that was matched with.
              matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                  Scores of predicted matches for each image pair
              all_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
                  Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2,
                  num_channels, num_keypoints)`.
              all_attentions (`tuple(torch.FloatTensor)`, *optional*):
                  Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
                  num_keypoints)`.
          """
          all_hidden_states = () if output_hidden_states else None
          all_attentions = () if output_attentions else None

          if keypoints.shape[2] == 0:  # no keypoints
              shape = keypoints.shape[:-1]
              return (
                  keypoints.new_full(shape, -1, dtype=torch.int),
                  keypoints.new_zeros(shape),
                  all_hidden_states,
                  all_attentions,
              )

          batch_size, _, num_keypoints, _ = keypoints.shape
          # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
          keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2)
          descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size)
          scores = scores.reshape(batch_size * 2, num_keypoints)
          mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None

          # Keypoint normalization
          keypoints = normalize_keypoints(keypoints, height, width)

          # Keypoint MLP encoder.
          encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states)
          last_hidden_state = encoded_keypoints[0]

          # Fuse the visual descriptors with the encoded keypoint positions/scores.
          descriptors = descriptors + last_hidden_state

          if mask is not None:
              input_shape = descriptors.size()
              extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
          else:
              # No padding to hide: skip masking altogether. A `(batch_size, num_keypoints)` tensor of ones (as was
              # built here before) cannot be regrouped into image pairs by the GNN cross-attention layers, while
              # `None` is handled by both the GNN and the attention modules.
              extended_attention_mask = None

          # Multi-layer Transformer network.
          gnn_outputs = self.gnn(
              descriptors,
              mask=extended_attention_mask,
              output_hidden_states=output_hidden_states,
              output_attentions=output_attentions,
          )
          descriptors = gnn_outputs[0]

          # Final MLP projection.
          projected_descriptors = self.final_projection(descriptors)

          # (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
          final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size)
          final_descriptors0 = final_descriptors[:, 0]
          final_descriptors1 = final_descriptors[:, 1]

          # Compute matching descriptor distance.
          scores = final_descriptors0 @ final_descriptors1.transpose(1, 2)
          scores = scores / self.config.hidden_size**0.5

          if mask is not None:
              # A pair (i, j) is valid only if keypoint i of image0 and keypoint j of image1 are both valid.
              mask = mask.reshape(batch_size, 2, num_keypoints)
              mask0 = mask[:, 0].unsqueeze(2)
              mask1 = mask[:, 1].unsqueeze(1)
              mask = torch.logical_and(mask0, mask1)
              scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

          # Run the optimal transport.
          scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations)

          # Get the matches with score above "match_threshold".
          # The last row and column (added by the optimal transport layer) are excluded.
          max0 = scores[:, :-1, :-1].max(2)
          max1 = scores[:, :-1, :-1].max(1)
          indices0 = max0.indices
          indices1 = max1.indices
          # A match is kept only when it is mutual: i's best match is j and j's best match is i.
          mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
          mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
          zero = scores.new_tensor(0)
          matching_scores0 = torch.where(mutual0, max0.values.exp(), zero)
          matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero)
          matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero)
          valid0 = mutual0 & (matching_scores0 > zero)
          valid1 = mutual1 & valid0.gather(1, indices1)
          matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
          matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

          matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)
          matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1)

          if output_hidden_states:
              all_hidden_states = all_hidden_states + encoded_keypoints[1]
              all_hidden_states = all_hidden_states + gnn_outputs[1]
              all_hidden_states = all_hidden_states + (projected_descriptors,)
              # (batch_size * 2, num_keypoints, channels) -> (batch_size, 2, channels, num_keypoints)
              all_hidden_states = tuple(
                  x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states
              )
          if output_attentions:
              all_attentions = all_attentions + gnn_outputs[2]
              all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions)

          return (
              matches,
              matching_scores,
              all_hidden_states,
              all_attentions,
          )

      @auto_docstring
      def forward(
          self,
          pixel_values: torch.FloatTensor,
          labels: torch.LongTensor | None = None,
          output_attentions: bool | None = None,
          output_hidden_states: bool | None = None,
          return_dict: bool | None = None,
          **kwargs,
      ) -> tuple | SuperGlueKeypointMatchingOutput:
          r"""
          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, AutoModel
          >>> import torch
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
          >>> with httpx.stream("GET", url) as response:
          ...     image_1 = Image.open(BytesIO(response.read()))

          >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
          >>> with httpx.stream("GET", url) as response:
          ...     image_2 = Image.open(BytesIO(response.read()))

          >>> images = [image_1, image_2]

          >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
          >>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

          >>> with torch.no_grad():
          ...     inputs = processor(images, return_tensors="pt")
          ...     outputs = model(**inputs)
          ```"""
          loss = None
          if labels is not None:
              raise ValueError("SuperGlue is not trainable, no labels should be provided.")

          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.return_dict

          if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
              raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")

          batch_size, _, channels, height, width = pixel_values.shape
          # Run the detector on every image independently: fold the pair dimension into the batch.
          pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
          keypoint_detections = self.keypoint_detector(pixel_values)

          keypoints, scores, descriptors, mask = keypoint_detections[:4]
          # Restore the pair dimension and align dtype/device with the input images.
          keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
          scores = scores.reshape(batch_size, 2, -1).to(pixel_values)
          descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values)
          mask = mask.reshape(batch_size, 2, -1)

          # Matching works on keypoints scaled by the image size; the keypoints returned to the caller are left
          # untouched (hence the clone).
          absolute_keypoints = keypoints.clone()
          absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
          absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height

          matches, matching_scores, hidden_states, attentions = self._match_image_pair(
              absolute_keypoints,
              descriptors,
              scores,
              height,
              width,
              mask=mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
          )

          if not return_dict:
              # Tuple output: entries that are None are dropped.
              return tuple(
                  v
                  for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions]
                  if v is not None
              )

          return SuperGlueKeypointMatchingOutput(
              loss=loss,
              matches=matches,
              matching_scores=matching_scores,
              keypoints=keypoints,
              mask=mask,
              hidden_states=hidden_states,
              attentions=attentions,
          )
  # Public API of this module.
  __all__ = ["SuperGluePreTrainedModel", "SuperGlueForKeypointMatching"]