| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757 |
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch SuperGlue model."""
- import math
- from dataclasses import dataclass
- import torch
- from torch import nn
- from transformers import PreTrainedModel
- from transformers.models.superglue.configuration_superglue import SuperGlueConfig
- from ... import initialization as init
- from ...utils import ModelOutput, auto_docstring, logging
- from ..auto import AutoModelForKeypointDetection
- logger = logging.get_logger(__name__)
def concat_pairs(tensor_tuple0: tuple[torch.Tensor], tensor_tuple1: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
    """
    Concatenate two tuples of tensors element by element.

    The i-th output tensor is `torch.cat` (along dim 0) of the i-th tensor of each input tuple. Iteration stops at
    the shorter of the two tuples.

    Args:
        tensor_tuple0 (`tuple[torch.Tensor]`):
            First tuple of tensors.
        tensor_tuple1 (`tuple[torch.Tensor]`):
            Second tuple of tensors.

    Returns:
        (`tuple[torch.Tensor]`): Tuple of concatenated tensors.
    """
    merged = []
    for left, right in zip(tensor_tuple0, tensor_tuple1):
        merged.append(torch.cat((left, right)))
    return tuple(merged)
def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Normalize keypoint locations with respect to the image size.

    Keypoints are shifted so that the image center becomes the origin, then divided by 0.7 times the longest image
    side.

    Args:
        keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
            Keypoints locations in (x, y) format.
        height (`int`):
            Image height.
        width (`int`):
            Image width.

    Returns:
        `torch.Tensor` of shape `(batch_size, num_keypoints, 2)`: Normalized keypoints locations.
    """
    # (1, 2) tensor holding (width, height), matching the (x, y) layout of the keypoints.
    image_size = keypoints.new_tensor([width, height]).unsqueeze(0)
    image_center = image_size / 2
    longest_side = image_size.max(1, keepdim=True).values
    scale = longest_side * 0.7
    # Insert a keypoint axis so both terms broadcast over `num_keypoints`.
    return (keypoints - image_center.unsqueeze(1)) / scale.unsqueeze(1)
def log_sinkhorn_iterations(
    log_cost_matrix: torch.Tensor,
    log_source_distribution: torch.Tensor,
    log_target_distribution: torch.Tensor,
    num_iterations: int,
) -> torch.Tensor:
    """
    Perform Sinkhorn normalization in log-space for numerical stability.

    Row and column scalings are updated alternately (rows first, then columns) for `num_iterations` rounds and are
    finally added to the input matrix. With `num_iterations=0` both scalings stay at zero.

    Args:
        log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
            Logarithm of the cost matrix.
        log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`):
            Logarithm of the source distribution.
        log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`):
            Logarithm of the target distribution.
        num_iterations (`int`):
            Number of alternating row/column updates.

    Returns:
        log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal
        transport matrix.
    """
    row_scaling = torch.zeros_like(log_source_distribution)
    column_scaling = torch.zeros_like(log_target_distribution)

    iteration = 0
    while iteration < num_iterations:
        # Row update: match the source marginal given the current column scaling.
        row_scaling = log_source_distribution - torch.logsumexp(
            log_cost_matrix + column_scaling.unsqueeze(1), dim=2
        )
        # Column update: match the target marginal given the fresh row scaling.
        column_scaling = log_target_distribution - torch.logsumexp(
            log_cost_matrix + row_scaling.unsqueeze(2), dim=1
        )
        iteration += 1

    return log_cost_matrix + row_scaling.unsqueeze(2) + column_scaling.unsqueeze(1)
def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor:
    """
    Perform differentiable optimal transport in log-space for numerical stability.

    The score matrix is augmented with one extra row and one extra column filled with `reg_param`, the marginals
    are built so that the extra row/column can absorb all points of the other side, and Sinkhorn iterations are run
    on the augmented matrix.

    Args:
        scores (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
            Cost matrix.
        reg_param (`torch.Tensor`):
            Regularization parameter used to fill the extra row and column. It is expanded with `Tensor.expand`, so
            it must be expandable to `(batch_size, 1, 1)`.
        iterations (`int`):
            Number of Sinkhorn iterations.

    Returns:
        log_optimal_transport_matrix (`torch.Tensor` of shape `(batch_size, num_rows + 1, num_columns + 1)`):
        Logarithm of the optimal transport matrix, including the extra row and column.
    """
    batch_size, num_rows, num_columns = scores.shape

    # Row/column counts as 0-dim tensors sharing the dtype and device of `scores`.
    unit = scores.new_tensor(1)
    rows_as_tensor = num_rows * unit
    columns_as_tensor = num_columns * unit

    # Assemble the augmented matrix:  [[scores, extra column], [extra row, corner]].
    extra_column = reg_param.expand(batch_size, num_rows, 1)
    extra_row = reg_param.expand(batch_size, 1, num_columns)
    corner = reg_param.expand(batch_size, 1, 1)
    upper_block = torch.cat([scores, extra_column], -1)
    lower_block = torch.cat([extra_row, corner], -1)
    couplings = torch.cat([upper_block, lower_block], 1)

    # log(1 / (num_rows + num_columns))
    log_normalization = -(rows_as_tensor + columns_as_tensor).log()

    # Regular entries get mass 1 / (M + N); the extra entry gets the count of the opposite side / (M + N).
    source_marginal = torch.cat(
        [log_normalization.expand(num_rows), columns_as_tensor.log()[None] + log_normalization]
    )
    target_marginal = torch.cat(
        [log_normalization.expand(num_columns), rows_as_tensor.log()[None] + log_normalization]
    )
    source_marginal = source_marginal[None].expand(batch_size, -1)
    target_marginal = target_marginal[None].expand(batch_size, -1)

    log_transport = log_sinkhorn_iterations(couplings, source_marginal, target_marginal, num_iterations=iterations)
    # Multiply probabilities by M + N.
    return log_transport - log_normalization
def arange_like(x, dim: int) -> torch.Tensor:
    """
    Return the 1D tensor `[0, 1, ..., x.shape[dim] - 1]`, created from `x` via `new_ones`.

    The values come from a cumulative sum of ones rather than from `torch.arange`.
    """
    length = x.shape[dim]
    running_count = x.new_ones(length).cumsum(0)
    return running_count - 1
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of SuperGlue keypoint matching models. Due to the nature of keypoint detection and matching, the number
    of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
    images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is
    used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching
    information.
    """
)
class SuperGlueKeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image, `-1` for keypoints without a match.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, 2, num_keypoints, 2)`):
        (x, y) coordinates of the keypoints predicted in each image of the pair.
    mask (`torch.IntTensor` of shape `(batch_size, 2, num_keypoints)`):
        Mask indicating which values in matches and matching_scores are keypoint matching information.
    hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`)
    attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`)
    """

    # Field order is part of the interface: it fixes the order of the tuple form of the output.
    loss: torch.FloatTensor | None = None
    matches: torch.FloatTensor | None = None
    matching_scores: torch.FloatTensor | None = None
    keypoints: torch.FloatTensor | None = None
    mask: torch.IntTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
class SuperGlueMultiLayerPerceptron(nn.Module):
    """Linear projection followed by 1D batch normalization and a ReLU activation."""

    def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None:
        # `config` is accepted for signature uniformity with the other modules; it is not read here.
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.batch_norm = nn.BatchNorm1d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply linear -> batch norm -> ReLU, keeping the channel dimension last in the output."""
        projected = self.linear(hidden_state)
        # The last two dims are swapped around the batch norm call and swapped back afterwards.
        normalized = self.batch_norm(projected.transpose(-1, -2)).transpose(-1, -2)
        return self.activation(normalized)
class SuperGlueKeypointEncoder(nn.Module):
    """Encode keypoint coordinates and detection scores into `config.hidden_size`-dimensional vectors."""

    def __init__(self, config: SuperGlueConfig) -> None:
        super().__init__()
        # 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint
        channels = [3] + config.keypoint_encoder_sizes + [config.hidden_size]

        # Every transition except the last one is a Linear + BatchNorm + ReLU block ...
        blocks = [
            SuperGlueMultiLayerPerceptron(config, fan_in, fan_out)
            for fan_in, fan_out in zip(channels[:-2], channels[1:-1])
        ]
        # ... and the last transition is a plain linear projection.
        blocks.append(nn.Linear(channels[-2], channels[-1]))
        self.encoder = nn.ModuleList(blocks)

    def forward(
        self,
        keypoints: torch.Tensor,
        scores: torch.Tensor,
        output_hidden_states: bool | None = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor] | None]:
        """
        Returns the final encoding and, when `output_hidden_states` is truthy, a tuple with the output of every
        encoder layer (otherwise `None`).
        """
        # Append the score as a third feature next to the (x, y) coordinates.
        hidden_state = torch.cat([keypoints, scores.unsqueeze(2)], dim=2)

        collected = []
        for layer in self.encoder:
            hidden_state = layer(hidden_state)
            if output_hidden_states:
                collected.append(hidden_state)

        all_hidden_states = tuple(collected) if output_hidden_states else None
        return hidden_state, all_hidden_states
- class SuperGlueSelfAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
- raise ValueError(
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
- f"heads ({config.num_attention_heads})"
- )
- self.num_attention_heads = config.num_attention_heads
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.is_decoder = config.is_decoder
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.FloatTensor | None = None,
- encoder_hidden_states: torch.FloatTensor | None = None,
- encoder_attention_mask: torch.FloatTensor | None = None,
- output_attentions: bool | None = False,
- ) -> tuple[torch.Tensor]:
- # If this is instantiated as a cross-attention module, the keys
- # and values come from an encoder; the attention mask needs to be
- # such that the encoder's padding tokens are not attended to.
- is_cross_attention = encoder_hidden_states is not None
- current_states = encoder_hidden_states if is_cross_attention else hidden_states
- attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
- batch_size = hidden_states.shape[0]
- key_layer = (
- self.key(current_states)
- .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
- .transpose(1, 2)
- )
- value_layer = (
- self.value(current_states)
- .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
- .transpose(1, 2)
- )
- query_layer = (
- self.query(hidden_states)
- .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
- .transpose(1, 2)
- )
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
- if attention_mask is not None:
- # Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function)
- attention_scores = attention_scores + attention_mask
- # Normalize the attention scores to probabilities.
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)
- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(new_context_layer_shape)
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
- if self.is_decoder:
- outputs = outputs + (None,)
- return outputs
class SuperGlueSelfOutput(nn.Module):
    """Square linear projection applied to the attention output."""

    def __init__(self, config: SuperGlueConfig):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)

    def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor:
        # Extra positional arguments are accepted and ignored.
        return self.dense(hidden_states)
# Lookup table from `config._attn_implementation` to the attention class instantiated by `SuperGlueAttention`.
SUPERGLUE_SELF_ATTENTION_CLASSES = {"eager": SuperGlueSelfAttention}
class SuperGlueAttention(nn.Module):
    """Attention block: the attention implementation selected by the config followed by an output projection."""

    def __init__(self, config):
        super().__init__()
        attention_class = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation]
        self.self = attention_class(config)
        self.output = SuperGlueSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        """Returns `(projected_context, *extras)` where `extras` is whatever the inner attention returned after its context."""
        context, *extras = self.self(
            hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        projected = self.output(context, hidden_states)
        # Extras carry the attention probabilities if they were requested.
        return (projected, *extras)
class SuperGlueAttentionalPropagation(nn.Module):
    """
    One message-passing step: attend (to the same or to another set of descriptors), concatenate the message with
    the input descriptors and pass the result through a small MLP.
    """

    def __init__(self, config: SuperGlueConfig) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        doubled = hidden_size * 2
        self.attention = SuperGlueAttention(config)
        # The MLP input is [descriptors, message], hence twice the hidden size.
        self.mlp = nn.ModuleList(
            [
                SuperGlueMultiLayerPerceptron(config, doubled, doubled),
                nn.Linear(doubled, hidden_size),
            ]
        )

    def forward(
        self,
        descriptors: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor] | None, tuple[torch.Tensor] | None]:
        """
        Returns `(update, hidden_states, attentions)`: the MLP output, the per-layer MLP outputs (or `None` when
        `output_hidden_states` is falsy) and the tuple of extra outputs returned by the attention block.
        """
        message, *attention_extras = self.attention(
            descriptors,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        attention = tuple(attention_extras)

        hidden_state = torch.cat([descriptors, message], dim=2)

        collected = []
        for layer in self.mlp:
            hidden_state = layer(hidden_state)
            if output_hidden_states:
                collected.append(hidden_state)
        all_hidden_states = tuple(collected) if output_hidden_states else None

        return hidden_state, all_hidden_states, attention
class SuperGlueAttentionalGNN(nn.Module):
    """
    Stack of attentional propagation layers, one per entry of `config.gnn_layers_types`.

    Layers whose type is `"cross"` attend to the descriptors of the other image of each pair; every other layer
    attends within the same image. The output of each layer is added residually to the descriptors.
    """

    def __init__(self, config: SuperGlueConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layers_types = config.gnn_layers_types
        self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in self.layers_types])

    def forward(
        self,
        descriptors: torch.Tensor,
        mask: torch.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool | None = False,
    ) -> tuple[torch.Tensor, tuple | None, tuple | None]:
        """
        Args:
            descriptors: tensor of shape `(batch_size, num_keypoints, hidden_size)`. For cross layers the batch
                dimension is regrouped as `(-1, 2, ...)`, i.e. consecutive rows are treated as an image pair.
            mask: optional additive attention mask, reshaped to `(batch_size, 1, 1, num_keypoints)` for cross layers.

        Returns:
            `(descriptors, all_hidden_states, all_attentions)`; the last two are `None` unless requested.
        """
        batch_size, num_keypoints, _ = descriptors.shape

        all_hidden_states = (descriptors,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for gnn_layer, layer_type in zip(self.layers, self.layers_types):
            if layer_type == "cross":
                # Swap the two images of every pair so each image attends to its counterpart.
                paired = descriptors.reshape(-1, 2, num_keypoints, self.hidden_size)
                other_descriptors = paired.flip(1).reshape(batch_size, num_keypoints, self.hidden_size)
                if mask is None:
                    other_mask = None
                else:
                    paired_mask = mask.reshape(-1, 2, 1, 1, num_keypoints)
                    other_mask = paired_mask.flip(1).reshape(batch_size, 1, 1, num_keypoints)
            else:
                other_descriptors = None
                other_mask = None

            delta, layer_hidden_states, layer_attentions = gnn_layer(
                descriptors,
                attention_mask=mask,
                encoder_hidden_states=other_descriptors,
                encoder_attention_mask=other_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
            )

            if output_hidden_states:
                all_hidden_states = all_hidden_states + layer_hidden_states
            if output_attentions:
                all_attentions = all_attentions + layer_attentions

            # Residual update.
            descriptors = descriptors + delta

        return descriptors, all_hidden_states, all_attentions
class SuperGlueFinalProjection(nn.Module):
    """Final square linear projection (with bias) of the descriptors."""

    def __init__(self, config: SuperGlueConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.final_proj = nn.Linear(width, width, bias=True)

    def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
        projected = self.final_proj(descriptors)
        return projected
@auto_docstring
class SuperGluePreTrainedModel(PreTrainedModel):
    # Shared base class for the SuperGlue models: declares the config type, the checkpoint prefix and the main
    # input, and extends the default weight initialization.
    config: SuperGlueConfig
    base_model_prefix = "superglue"
    main_input_name = "pixel_values"
    input_modalities = ("image",)

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights"""
        super()._init_weights(module)
        if not hasattr(module, "bin_score"):
            return
        # Modules that own a `bin_score` get it reset to one.
        init.ones_(module.bin_score)
@auto_docstring(
    custom_intro="""
    SuperGlue model taking images as inputs and outputting the matching of them.
    """
)
class SuperGlueForKeypointMatching(SuperGluePreTrainedModel):
    """SuperGlue feature matching middle-end

    Given two sets of keypoints and locations, we determine the
    correspondences by:
      1. Keypoint Encoding (normalization + visual feature and location fusion)
      2. Graph Neural Network with multiple self and cross-attention layers
      3. Final projection layer
      4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
      5. Thresholding matrix based on mutual exclusivity and a match_threshold

    The correspondence ids use -1 to indicate non-matching points.

    Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
    Networks. In CVPR, 2020. https://huggingface.co/papers/1911.11763
    """

    def __init__(self, config: SuperGlueConfig) -> None:
        super().__init__(config)

        self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)

        self.keypoint_encoder = SuperGlueKeypointEncoder(config)
        self.gnn = SuperGlueAttentionalGNN(config)
        self.final_projection = SuperGlueFinalProjection(config)

        # Score assigned to the extra "unmatched" row/column of the optimal transport problem.
        bin_score = torch.nn.Parameter(torch.tensor(1.0))
        self.register_parameter("bin_score", bin_score)

        self.post_init()

    def _match_image_pair(
        self,
        keypoints: torch.Tensor,
        descriptors: torch.Tensor,
        scores: torch.Tensor,
        height: int,
        width: int,
        mask: torch.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, tuple, tuple]:
        """
        Perform keypoint matching between two images.

        Args:
            keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                Keypoints detected in the pair of image.
            descriptors (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, descriptor_dim)`):
                Descriptors of the keypoints detected in the image pair. `descriptor_dim` must equal
                `config.hidden_size`.
            scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                Confidence scores of the keypoints detected in the image pair.
            height (`int`): Image height.
            width (`int`): Image width.
            mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*):
                Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint matching
                information. When `None`, no attention or score masking is applied.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors. Default to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Default to `config.output_hidden_states`.

        Returns:
            matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                For each image pair, for each keypoint in image0, the index of the keypoint in image1 that was matched
                with. And for each keypoint in image1, the index of the keypoint in image0 that was matched with.
                Unmatched keypoints are set to -1.
            matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
                Scores of predicted matches for each image pair
            all_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
                Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
                num_keypoints)`.
            all_attentions (`tuple(torch.FloatTensor)`, *optional*):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
                num_keypoints)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if keypoints.shape[2] == 0:  # no keypoints
            shape = keypoints.shape[:-1]
            return (
                keypoints.new_full(shape, -1, dtype=torch.int),
                keypoints.new_zeros(shape),
                all_hidden_states,
                all_attentions,
            )

        batch_size, _, num_keypoints, _ = keypoints.shape
        # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
        keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2)
        descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size)
        scores = scores.reshape(batch_size * 2, num_keypoints)
        mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None

        # Keypoint normalization
        keypoints = normalize_keypoints(keypoints, height, width)

        encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states)

        last_hidden_state = encoded_keypoints[0]

        # Keypoint MLP encoder.
        descriptors = descriptors + last_hidden_state

        if mask is not None:
            input_shape = descriptors.size()
            extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
        else:
            # No mask means "attend everywhere". Passing `None` lets the GNN and attention layers skip masking.
            # An all-ones additive mask would not change the softmax either, but a `(batch_size, num_keypoints)`
            # tensor does not line up with the `batch_size * 2` flattened descriptors and cannot be reshaped by
            # the cross-attention layers.
            extended_attention_mask = None

        # Multi-layer Transformer network.
        gnn_outputs = self.gnn(
            descriptors,
            mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        descriptors = gnn_outputs[0]

        # Final MLP projection.
        projected_descriptors = self.final_projection(descriptors)

        # (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
        final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size)
        final_descriptors0 = final_descriptors[:, 0]
        final_descriptors1 = final_descriptors[:, 1]

        # Compute matching descriptor distance.
        scores = final_descriptors0 @ final_descriptors1.transpose(1, 2)
        scores = scores / self.config.hidden_size**0.5

        if mask is not None:
            # A score is kept only when both the image0 keypoint and the image1 keypoint are valid.
            mask = mask.reshape(batch_size, 2, num_keypoints)
            mask0 = mask[:, 0].unsqueeze(2)
            mask1 = mask[:, 1].unsqueeze(1)
            mask = torch.logical_and(mask0, mask1)
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Run the optimal transport.
        scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations)

        # Get the matches with score above "match_threshold".
        # The last row and column (the extra "unmatched" bin) are excluded from the argmax.
        max0 = scores[:, :-1, :-1].max(2)
        max1 = scores[:, :-1, :-1].max(1)
        indices0 = max0.indices
        indices1 = max1.indices
        # A match is mutual when the best partner of a keypoint points back to that keypoint.
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        matching_scores0 = torch.where(mutual0, max0.values.exp(), zero)
        matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero)
        matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero)
        valid0 = mutual0 & (matching_scores0 > zero)
        valid1 = mutual1 & valid0.gather(1, indices1)
        matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

        matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)
        matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + encoded_keypoints[1]
            all_hidden_states = all_hidden_states + gnn_outputs[1]
            all_hidden_states = all_hidden_states + (projected_descriptors,)
            # (batch_size * 2, num_keypoints, channels) -> (batch_size, 2, channels, num_keypoints)
            all_hidden_states = tuple(
                x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states
            )
        if output_attentions:
            all_attentions = all_attentions + gnn_outputs[2]
            all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions)

        return (
            matches,
            matching_scores,
            all_hidden_states,
            all_attentions,
        )

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SuperGlueKeypointMatchingOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModel
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
        >>> with httpx.stream("GET", url) as response:
        ...     image_1 = Image.open(BytesIO(response.read()))

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
        >>> with httpx.stream("GET", url) as response:
        ...     image_2 = Image.open(BytesIO(response.read()))

        >>> images = [image_1, image_2]

        >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
        >>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")

        >>> with torch.no_grad():
        ...     inputs = processor(images, return_tensors="pt")
        ...     outputs = model(**inputs)
        ```"""
        loss = None
        if labels is not None:
            raise ValueError("SuperGlue is not trainable, no labels should be provided.")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
            raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")

        batch_size, _, channels, height, width = pixel_values.shape
        # Run the detector on every image independently: (batch_size, 2, ...) -> (batch_size * 2, ...)
        pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
        keypoint_detections = self.keypoint_detector(pixel_values)

        keypoints, scores, descriptors, mask = keypoint_detections[:4]
        keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
        scores = scores.reshape(batch_size, 2, -1).to(pixel_values)
        descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values)
        mask = mask.reshape(batch_size, 2, -1)

        # Matching works on coordinates scaled by the image size; the unscaled `keypoints` are what is returned.
        absolute_keypoints = keypoints.clone()
        absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
        absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height

        matches, matching_scores, hidden_states, attentions = self._match_image_pair(
            absolute_keypoints,
            descriptors,
            scores,
            height,
            width,
            mask=mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            # Tuple form drops every entry that is `None` (e.g. `loss`, or outputs that were not requested).
            return tuple(
                v
                for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions]
                if v is not None
            )

        return SuperGlueKeypointMatchingOutput(
            loss=loss,
            matches=matches,
            matching_scores=matching_scores,
            keypoints=keypoints,
            mask=mask,
            hidden_states=hidden_states,
            attentions=attentions,
        )
# Public names exported by this module.
__all__ = [
    "SuperGluePreTrainedModel",
    "SuperGlueForKeypointMatching",
]
|