| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- from dataclasses import dataclass
- import numpy as np
- import torch
- from huggingface_hub.dataclasses import strict
- from torch import nn
- from torch.nn.utils.rnn import pad_sequence
- from ...configuration_utils import PreTrainedConfig
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import ImagesKwargs, Unpack
- from ...utils import ModelOutput, TensorType, auto_docstring, can_return_tuple, logging
- from ...utils.import_utils import requires
- from ..auto import CONFIG_MAPPING, AutoConfig
- from ..auto.modeling_auto import AutoModelForKeypointDetection
- from ..clip.modeling_clip import CLIPMLP
- from ..cohere.modeling_cohere import apply_rotary_pos_emb
- from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
- from ..superglue.image_processing_pil_superglue import SuperGlueImageProcessorPil
- from ..superglue.image_processing_superglue import SuperGlueImageProcessor
- from ..superpoint import SuperPointConfig
# Module-level logger obtained from the library's `logging` utilities.
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="ETH-CVG/lightglue_superpoint")
@strict
class LightGlueConfig(PreTrainedConfig):
    r"""
    keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
        The config object or dictionary of the keypoint detector.
    descriptor_dim (`int`, *optional*, defaults to 256):
        The dimension of the descriptors.
    depth_confidence (`float`, *optional*, defaults to 0.95):
        The confidence threshold used to perform early stopping.
    width_confidence (`float`, *optional*, defaults to 0.99):
        The confidence threshold used to prune points.
    filter_threshold (`float`, *optional*, defaults to 0.1):
        The confidence threshold used to filter matches.

    Examples:

    ```python
    >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching

    >>> # Initializing a LightGlue style configuration
    >>> configuration = LightGlueConfig()

    >>> # Initializing a model from the LightGlue style configuration
    >>> model = LightGlueForKeypointMatching(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "lightglue"
    sub_configs = {"keypoint_detector_config": AutoConfig}

    keypoint_detector_config: dict | SuperPointConfig | None = None
    descriptor_dim: int = 256
    num_hidden_layers: int = 9
    num_attention_heads: int = 4
    num_key_value_heads: int | None = None
    depth_confidence: float = 0.95
    width_confidence: float = 0.99
    filter_threshold: float = 0.1
    initializer_range: float = 0.02
    hidden_act: str = "gelu"
    attention_dropout: float | int = 0.0
    attention_bias: bool = True

    def __post_init__(self, **kwargs):
        """
        Fill derived attributes and materialize the keypoint detector sub-config.

        - `num_key_value_heads` defaults to `num_attention_heads`.
        - `keypoint_detector_config` given as a dict is turned into a config object (model type defaults to
          `"superpoint"`); when missing, a default SuperPoint config is created. In both cases the detector is
          forced to eager attention.
        - `hidden_size` mirrors `descriptor_dim` and `intermediate_size` is `2 * descriptor_dim`.
        """
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention
        # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
        if isinstance(self.keypoint_detector_config, dict):
            # Build a fresh dict: the caller's dict is left untouched, and an `attn_implementation` entry already
            # present in it is overridden instead of being passed twice as a keyword (which raised a TypeError).
            detector_kwargs = {**self.keypoint_detector_config, "attn_implementation": "eager"}
            detector_model_type = detector_kwargs.setdefault("model_type", "superpoint")
            self.keypoint_detector_config = CONFIG_MAPPING[detector_model_type](**detector_kwargs)
        elif self.keypoint_detector_config is None:
            self.keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")

        # The MLPs consume the concatenation of two descriptors, hence twice the descriptor width.
        self.intermediate_size = self.descriptor_dim * 2
        self.hidden_size = self.descriptor_dim
        super().__post_init__(**kwargs)

    def validate_architecture(self):
        """Part of `@strict`-powered validation. Validates the architecture of the config."""
        # Attention heads split the descriptor evenly.
        if self.descriptor_dim % self.num_attention_heads != 0:
            raise ValueError("descriptor_dim % num_heads is different from zero")
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
    the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
    batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
    tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
    matching information.
    """
)
class LightGlueKeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
        Pruning mask indicating which keypoints are removed and at which layer.
    mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
        Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
        information.
    hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)` returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`
    attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)` returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`
    """

    loss: torch.FloatTensor | None = None
    # NOTE(review): annotated/documented as FloatTensor, but `get_matches_from_scores` produces integer keypoint
    # indices with -1 meaning "no match" — confirm the dtype actually stored here.
    matches: torch.FloatTensor | None = None
    matching_scores: torch.FloatTensor | None = None
    keypoints: torch.FloatTensor | None = None
    prune: torch.IntTensor | None = None
    # NOTE(review): annotated as FloatTensor while the docstring says BoolTensor — confirm which one is produced.
    mask: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
# Typed keyword arguments for the LightGlue image processors: the standard `ImagesKwargs` plus `do_grayscale`.
class LightGlueImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
        Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
    """

    do_grayscale: bool
class LightGlueImageProcessor(SuperGlueImageProcessor):
    # Reuses all of SuperGlueImageProcessor; only the post-processing signature is re-declared for LightGlue outputs.
    def post_process_keypoint_matching(
        self,
        outputs: "LightGlueKeypointMatchingOutput",
        target_sizes: TensorType | list[tuple],
        threshold: float = 0.0,
    ) -> list[dict[str, torch.Tensor]]:
        # Delegates unchanged to `SuperGlueImageProcessor.post_process_keypoint_matching`; the override only types
        # `outputs` as a LightGlue output.
        return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
@requires(backends=("torch",))
class LightGlueImageProcessorPil(SuperGlueImageProcessorPil):
    # PIL-backed counterpart of LightGlueImageProcessor; everything is inherited from SuperGlueImageProcessorPil.
    @requires(backends=("torch",))
    def post_process_keypoint_matching(
        self,
        outputs: "LightGlueKeypointMatchingOutput",
        target_sizes: TensorType | list[tuple],
        threshold: float = 0.0,
    ) -> list[dict[str, "torch.Tensor"]]:
        # Delegates unchanged to the parent implementation; re-declared to type `outputs` as a LightGlue output.
        return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
class LightGluePositionalEncoder(nn.Module):
    """
    Positional encoding of keypoint coordinates used as rotary embeddings by the attention layers.

    The 2-d coordinates are linearly projected to half of the per-head dimension. Every projected value is
    duplicated, then turned into a `(cos, sin)` pair.
    """

    def __init__(self, config: LightGlueConfig):
        super().__init__()
        head_dim = config.descriptor_dim // config.num_attention_heads
        self.projector = nn.Linear(2, head_dim // 2, bias=False)

    def forward(
        self, keypoints: torch.Tensor, output_hidden_states: bool | None = False
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            keypoints: coordinates whose last dimension has size 2 (the projector input width).
            output_hidden_states: when truthy, also return the projected coordinates.

        Returns:
            `((cos, sin),)`, or `((cos, sin), projected_keypoints)` when `output_hidden_states` is truthy.
        """
        projected_keypoints = self.projector(keypoints)
        angles = projected_keypoints.repeat_interleave(2, dim=-1)
        rotary = (torch.cos(angles), torch.sin(angles))
        if output_hidden_states:
            return rotary, projected_keypoints
        return (rotary,)
class LightGlueAttention(LlamaAttention):
    """
    Attention block used for both self- and cross-attention in a LightGlue transformer layer.

    Keys/values come from `encoder_hidden_states` when it is given (cross-attention), otherwise from
    `hidden_states` (self-attention). Rotary embeddings are applied to queries and keys only when
    `position_embeddings` is provided.
    """

    def __init__(self, config: LightGlueConfig, layer_idx: int):
        # NOTE(review): presumably a modular-transformers definition whose converter expands the argument-less
        # `super().__init__()` into the parent's body; as plain Python the parent would need `config`/`layer_idx`
        # — confirm before touching.
        super().__init__()
        # Rotary cos/sin are supplied through `position_embeddings` in `forward`, so the inherited rotary
        # module (if the parent defines one — confirm) is removed.
        del self.rotary_emb

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Split the last dimension into heads: (..., seq, hidden) -> (..., heads, seq, head_dim) after transpose.
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Cross-attention reads keys/values (and mask) from the other sequence.
        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        # `hidden_shape` is derived from the query input, so `current_states` must share its leading dimensions.
        key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Dispatch to the configured attention backend, falling back to the eager implementation.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            current_attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the hidden dimension and project.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights
class LightGlueMLP(CLIPMLP):
    """
    Feed-forward block applied to the concatenation of a descriptor and its attention message.

    `fc1` is redefined as `intermediate_size -> intermediate_size` (the input is two concatenated descriptors) and is
    followed by a LayerNorm before the activation.
    """

    def __init__(self, config: LightGlueConfig):
        super().__init__(config)
        self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
        self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # fc1 -> LayerNorm -> activation -> fc2, applied one after the other.
        for stage in (self.fc1, self.layer_norm, self.activation_fn, self.fc2):
            hidden_states = stage(hidden_states)
        return hidden_states
class LightGlueTransformerLayer(nn.Module):
    """
    One LightGlue layer: a self-attention block followed by a cross-attention block.

    Each block concatenates its input with the attention output, runs an MLP on the concatenation and adds the MLP
    output back to the input as a residual. Images are laid out as consecutive pairs along the batch dimension.
    """

    def __init__(self, config: LightGlueConfig, layer_idx: int):
        super().__init__()
        self.self_attention = LightGlueAttention(config, layer_idx)
        self.self_mlp = LightGlueMLP(config)
        self.cross_attention = LightGlueAttention(config, layer_idx)
        self.cross_mlp = LightGlueMLP(config)

    def forward(
        self,
        descriptors: torch.Tensor,
        keypoints: torch.Tensor,
        attention_mask: torch.Tensor,
        output_hidden_states: bool | None = False,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor] | None, tuple[torch.Tensor] | None]:
        """
        Args:
            descriptors: `(batch_size, num_keypoints, descriptor_dim)`, where consecutive rows form an image pair.
            keypoints: rotary `(cos, sin)` embeddings, used by the self-attention block only.
            attention_mask: extended attention mask or `None`.
            output_hidden_states: also return intermediate states.
            output_attentions: also return both attention maps.

        Returns:
            `(updated_descriptors, hidden_states_or_None, attentions_or_None)`.
        """
        batch_size, num_keypoints, descriptor_dim = descriptors.shape
        flat_shape = (batch_size, num_keypoints, descriptor_dim)

        # --- Self attention: each image attends to its own keypoints.
        self_attn_out, self_attentions = self.self_attention(
            descriptors,
            position_embeddings=keypoints,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        self_mlp_in = torch.cat([descriptors, self_attn_out], dim=-1)
        self_mlp_out = self.self_mlp(self_mlp_in)
        after_self = descriptors + self_mlp_out

        # --- Cross attention: group rows by pair, swap (image0, image1) -> (image1, image0) so every image sees its
        # partner as the key/value source, then flatten back.
        partner_states = (
            after_self.reshape(-1, 2, num_keypoints, descriptor_dim).flip(1).reshape(flat_shape)
        )
        partner_mask = None
        if attention_mask is not None:
            paired_mask = attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1)
            partner_mask = paired_mask.reshape(batch_size, 1, 1, num_keypoints)

        cross_attn_out, cross_attentions = self.cross_attention(
            after_self,
            encoder_hidden_states=partner_states,
            encoder_attention_mask=partner_mask,
            output_attentions=output_attentions,
        )
        cross_mlp_in = torch.cat([after_self, cross_attn_out], dim=-1)
        cross_mlp_out = self.cross_mlp(cross_mlp_in)
        updated = after_self + cross_mlp_out

        all_hidden_states = None
        if output_hidden_states:
            all_hidden_states = (
                descriptors,
                after_self.reshape(flat_shape),
                self_mlp_in,
                self_mlp_out,
                updated.reshape(flat_shape),
                cross_mlp_in,
                cross_mlp_out,
            )

        all_attentions = (self_attentions, cross_attentions) if output_attentions else None

        return updated, all_hidden_states, all_attentions
def sigmoid_log_double_softmax(
    similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
) -> torch.Tensor:
    """
    Build the log assignment matrix (with dustbin row/column) from similarity and matchability logits.

    Args:
        similarity: `(batch_size, M, N)` similarity logits.
        matchability0: `(batch_size, M, 1)` matchability logits of the first image.
        matchability1: `(batch_size, N, 1)` matchability logits of the second image.

    Returns:
        `(batch_size, M + 1, N + 1)` tensor. The inner block is row log-softmax + column log-softmax + both
        log-sigmoid matchabilities. The last column/row hold `logsigmoid(-matchability)` and the corner is 0.
    """
    batch_size, rows, cols = similarity.shape
    log_sigmoid = nn.functional.logsigmoid

    pair_certainty = log_sigmoid(matchability0) + log_sigmoid(matchability1).transpose(1, 2)
    row_log_probs = nn.functional.log_softmax(similarity, 2)
    col_log_probs = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)

    scores = similarity.new_full((batch_size, rows + 1, cols + 1), 0)
    scores[:, :rows, :cols] = row_log_probs + col_log_probs + pair_certainty
    # Dustbins: probability of a keypoint being unmatched.
    scores[:, :-1, -1] = log_sigmoid(-matchability0.squeeze(-1))
    scores[:, -1, :-1] = log_sigmoid(-matchability1.squeeze(-1))
    return scores
class LightGlueMatchAssignmentLayer(nn.Module):
    """Produces the log assignment matrix between the two images of each pair, plus per-keypoint matchability."""

    def __init__(self, config: LightGlueConfig):
        super().__init__()
        self.descriptor_dim = config.descriptor_dim
        self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
        self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)

    def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            descriptors: `(batch_size, num_keypoints, descriptor_dim)` with consecutive rows forming image pairs.
            mask: optional keypoint validity mask reshapeable to `(batch_size // 2, 2, num_keypoints)`.

        Returns:
            `(batch_size // 2, num_keypoints + 1, num_keypoints + 1)` log assignment scores.
        """
        batch_size, num_keypoints, descriptor_dim = descriptors.shape
        num_pairs = batch_size // 2

        # Project and scale by descriptor_dim ** 0.25 on each side (i.e. sqrt(d) overall in the dot product).
        projected = self.final_projection(descriptors)
        projected = projected / torch.tensor(self.descriptor_dim, device=projected.device) ** 0.25
        left, right = projected.reshape(num_pairs, 2, num_keypoints, descriptor_dim).unbind(dim=1)
        similarity = left @ right.transpose(-1, -2)

        if mask is not None:
            pair_mask = mask.reshape(num_pairs, 2, num_keypoints)
            # Outer product: a similarity entry is valid only if both keypoints are valid.
            valid = pair_mask[:, 0].unsqueeze(-1) * pair_mask[:, 1].unsqueeze(-2)
            similarity = similarity.masked_fill(valid == 0, torch.finfo(similarity.dtype).min)

        logits = self.matchability(descriptors).reshape(num_pairs, 2, num_keypoints, 1)
        return sigmoid_log_double_softmax(similarity, logits[:, 0], logits[:, 1])

    def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
        """Get matchability of descriptors as a probability"""
        return nn.functional.sigmoid(self.matchability(descriptors)).squeeze(-1)
class LightGlueTokenConfidenceLayer(nn.Module):
    """Per-keypoint confidence head: linear layer + sigmoid, evaluated on detached descriptors."""

    def __init__(self, config: LightGlueConfig):
        super().__init__()
        self.token = nn.Linear(config.descriptor_dim, 1)

    def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
        # Detached input: gradients from this head do not reach the descriptors.
        logits = self.token(descriptors.detach())
        return nn.functional.sigmoid(logits).squeeze(-1)
@auto_docstring
class LightGluePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config: LightGlueConfig
    base_model_prefix = "lightglue"
    # Primary model input is the image tensor.
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    # Gradient checkpointing is declared unsupported.
    supports_gradient_checkpointing = False
    # Flash-attention and SDPA attention backends are declared as supported.
    _supports_flash_attn = True
    _supports_sdpa = True
def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Extract mutual nearest-neighbour matches from a log assignment matrix `[B x M+1 x N+1]`.

    The dustbin row/column (last index) is ignored. A match is kept only if it is mutual and its probability
    (`exp` of the log score) exceeds `threshold`; rejected matches are set to -1.

    Returns:
        `(matches, matching_scores)`, each of shape `(2 * B, K)`. Rows alternate image0 / image1 of every pair.
        This requires `M == N`.
    """
    batch_size, _rows, _cols = scores.shape
    inner = scores[:, :-1, :-1]

    # Best candidate for each keypoint, in both directions.
    best_in_row = inner.max(2)
    best_in_col = inner.max(1)
    matches0, matches1 = best_in_row.indices, best_in_col.indices

    # Mutual check: my best match must pick me back.
    positions0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
    positions1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
    mutual0 = matches1.gather(1, matches0) == positions0
    mutual1 = matches0.gather(1, matches1) == positions1

    # Probabilities of mutual matches (0 elsewhere), then thresholding.
    row_probs = best_in_row.values.exp()
    zero = row_probs.new_tensor(0)
    scores0 = torch.where(mutual0, row_probs, zero)
    scores1 = torch.where(mutual1, scores0.gather(1, matches1), zero)
    keep0 = mutual0 & (scores0 > threshold)
    keep1 = mutual1 & keep0.gather(1, matches1)

    matches0 = torch.where(keep0, matches0, -1)
    matches1 = torch.where(keep1, matches1, -1)

    matches = torch.stack([matches0, matches1], dim=1).reshape(batch_size * 2, -1)
    matching_scores = torch.stack([scores0, scores1], dim=1).reshape(batch_size * 2, -1)
    return matches, matching_scores
def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """
    Center keypoint coordinates on the image and scale them by half of the longer image side.

    Args:
        keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
            Keypoints locations in (x, y) format.
        height (`int`):
            Image height.
        width (`int`):
            Image width.

    Returns:
        `torch.Tensor` of shape `(batch_size, num_keypoints, 2)`: normalized keypoint locations. The longer side
        maps to [-1, 1].
    """
    image_size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
    center = image_size / 2
    half_extent = image_size.max(-1).values / 2
    return (keypoints - center[..., None, :]) / half_extent[..., None, None]
- @auto_docstring(
- custom_intro="""
- LightGlue model taking images as inputs and outputting the matching of them.
- """
- )
- class LightGlueForKeypointMatching(LightGluePreTrainedModel):
- """
- LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
- SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
- It consists of :
- 1. Keypoint Encoder
- 2. A Graph Neural Network with self and cross attention layers
- 3. Matching Assignment layers
- The correspondence ids use -1 to indicate non-matching points.
- Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
- In ICCV 2023. https://huggingface.co/papers/2306.13643
- """
    def __init__(self, config: LightGlueConfig):
        """
        Build the keypoint detector, the optional input projection, the positional encoder and the per-layer
        transformer / match-assignment / token-confidence modules.

        Args:
            config ([`LightGlueConfig`]):
                Model configuration; `config.keypoint_detector_config` instantiates the keypoint detector.
        """
        super().__init__(config)

        self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)

        # Width of the descriptors produced by the detector.
        self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
        self.descriptor_dim = config.descriptor_dim
        self.num_layers = config.num_hidden_layers
        self.filter_threshold = config.filter_threshold
        self.depth_confidence = config.depth_confidence
        self.width_confidence = config.width_confidence

        # Project detector descriptors to `descriptor_dim` only when the two widths differ.
        if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
            self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
        else:
            self.input_projection = nn.Identity()

        self.positional_encoder = LightGluePositionalEncoder(config)

        self.transformer_layers = nn.ModuleList(
            [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        self.match_assignment_layers = nn.ModuleList(
            [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # One fewer confidence head than layers; presumably because early stopping is not evaluated after the
        # final layer — confirm against the forward pass.
        self.token_confidence = nn.ModuleList(
            [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
        )

        self.post_init()
- def _get_confidence_threshold(self, layer_index: int) -> float:
- """scaled confidence threshold for a given layer"""
- threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
- return np.clip(threshold, 0, 1)
- def _keypoint_processing(
- self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: bool | None = False
- ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
- descriptors = descriptors.detach().contiguous()
- projected_descriptors = self.input_projection(descriptors)
- keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
- return projected_descriptors, keypoint_encoding_output
- def _get_early_stopped_image_pairs(
- self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
- ) -> torch.Tensor:
- """evaluate whether we should stop inference based on the confidence of the keypoints"""
- batch_size, _ = mask.shape
- if layer_index < self.num_layers - 1:
- # If the current layer is not the last layer, we compute the confidence of the keypoints and check
- # if we should stop the forward pass through the transformer layers for each pair of images.
- keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
- keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
- threshold = self._get_confidence_threshold(layer_index)
- ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
- early_stopped_pairs = ratio_confident > self.depth_confidence
- else:
- # If the current layer is the last layer, we stop the forward pass through the transformer layers for
- # all pairs of images.
- early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
- return early_stopped_pairs
- def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
- if early_stops is not None:
- descriptors = descriptors[early_stops]
- mask = mask[early_stops]
- scores = self.match_assignment_layers[layer_index](descriptors, mask)
- matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
- return matches, matching_scores
- def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
- """mask points which should be removed"""
- keep = scores > (1 - self.width_confidence)
- if confidences is not None: # Low-confidence points are never pruned.
- keep |= confidences <= self._get_confidence_threshold(layer_index)
- return keep
- def _do_layer_keypoint_pruning(
- self,
- descriptors: torch.Tensor,
- keypoints: torch.Tensor,
- mask: torch.Tensor,
- indices: torch.Tensor,
- prune_output: torch.Tensor,
- keypoint_confidences: torch.Tensor,
- layer_index: int,
- ):
- """
- For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
- descriptors.
- """
- batch_size, _, _ = descriptors.shape
- descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
- pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
- pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))
- # For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
- pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
- [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
- for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
- )
- for i in range(batch_size):
- prune_output[i, pruned_indices[i]] += 1
- # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
- pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
- pad_sequence(pruned_tensor, batch_first=True)
- for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
- )
- pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
- pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
- return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
- def _concat_early_stopped_outputs(
- self,
- early_stops_indices,
- final_pruned_keypoints_indices,
- final_pruned_keypoints_iterations,
- matches,
- matching_scores,
- ):
- early_stops_indices = torch.stack(early_stops_indices)
- # Rearrange tensors to have the same order as the input batch
- ids = torch.arange(early_stops_indices.shape[0])
- order_indices = early_stops_indices[ids]
- early_stops_indices = early_stops_indices[order_indices]
- matches, final_pruned_keypoints_indices = (
- pad_sequence(tensor, batch_first=True, padding_value=-1)
- for tensor in [matches, final_pruned_keypoints_indices]
- )
- matching_scores, final_pruned_keypoints_iterations = (
- pad_sequence(tensor, batch_first=True, padding_value=0)
- for tensor in [matching_scores, final_pruned_keypoints_iterations]
- )
- matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
- tensor[early_stops_indices]
- for tensor in [
- matches,
- matching_scores,
- final_pruned_keypoints_indices,
- final_pruned_keypoints_iterations,
- ]
- )
- return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
- def _do_final_keypoint_pruning(
- self,
- indices: torch.Tensor,
- matches: torch.Tensor,
- matching_scores: torch.Tensor,
- num_keypoints: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
- # have tensors from
- batch_size, _ = indices.shape
- indices, matches, matching_scores = (
- tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
- )
- indices0 = indices[:, 0]
- indices1 = indices[:, 1]
- matches0 = matches[:, 0]
- matches1 = matches[:, 1]
- matching_scores0 = matching_scores[:, 0]
- matching_scores1 = matching_scores[:, 1]
- # Prepare final matches and matching scores
- _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
- _matching_scores = torch.zeros(
- (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
- )
- # Fill the matches and matching scores for each image pair
- for i in range(batch_size // 2):
- _matches[i, 0, indices0[i]] = torch.where(
- matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
- )
- _matches[i, 1, indices1[i]] = torch.where(
- matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
- )
- _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
- _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
- return _matches, _matching_scores
def _match_image_pair(
    self,
    keypoints: torch.Tensor,
    descriptors: torch.Tensor,
    height: int,
    width: int,
    mask: torch.Tensor | None = None,
    output_attentions: bool | None = None,
    output_hidden_states: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
    """Match the keypoints of the two images of every pair in the batch.

    The descriptors are run through up to ``self.num_layers`` transformer layers. When
    ``self.depth_confidence > 0`` a pair can leave the forward pass early once its keypoints are confident
    enough. When ``self.width_confidence > 0`` low-confidence keypoints are pruned between layers.

    Args:
        keypoints: Keypoint coordinates of shape ``(batch_size, 2, num_keypoints, 2)``.
        descriptors: Keypoint descriptors, reshaped internally to
            ``(batch_size * 2, num_keypoints, self.keypoint_detector_descriptor_dim)``.
        height: Image height, forwarded to ``normalize_keypoints``.
        width: Image width, forwarded to ``normalize_keypoints``.
        mask: Optional keypoint validity mask with ``batch_size * 2 * num_keypoints`` elements. When ``None``,
            every keypoint is treated as valid.
        output_attentions: Whether to accumulate the attentions returned by each transformer layer.
        output_hidden_states: Whether to accumulate the hidden states returned by each transformer layer.

    Returns:
        ``(matches, matching_scores, pruned_keypoints_iterations, all_hidden_states, all_attentions)``.
        ``pruned_keypoints_iterations`` has shape ``(batch_size, 2, num_keypoints)``. When there are no
        keypoints, matches are filled with -1 and the scores/iterations with zeros.
    """
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None

    if keypoints.shape[2] == 0:  # no keypoints
        shape = keypoints.shape[:-1]
        return (
            keypoints.new_full(shape, -1, dtype=torch.int),
            keypoints.new_zeros(shape),
            keypoints.new_zeros(shape),
            all_hidden_states,
            all_attentions,
        )

    device = keypoints.device
    batch_size, _, initial_num_keypoints, _ = keypoints.shape
    if mask is None:
        # A missing mask means every keypoint is valid. Previously this path crashed on `mask.reshape`
        # just below and again when slicing `mask` for early-stopped pairs.
        # NOTE(review): torch.int is assumed to match the mask dtype produced by the keypoint detector and
        # expected by the mask-consuming helpers — confirm.
        mask = torch.ones((batch_size, 2, initial_num_keypoints), dtype=torch.int, device=device)
    # Valid keypoints per image pair; sliced alongside the pairs as they early stop.
    num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)

    # (batch_size, 2, num_keypoints, ...) -> (batch_size * 2, num_keypoints, ...): one row per image
    keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
    mask = mask.reshape(batch_size * 2, initial_num_keypoints)
    descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
    # Original row index of every image still in the forward pass.
    image_indices = torch.arange(batch_size * 2, device=device)

    # Keypoint normalization
    keypoints = normalize_keypoints(keypoints, height, width)
    descriptors, keypoint_encoding_output = self._keypoint_processing(
        descriptors, keypoints, output_hidden_states=output_hidden_states
    )
    # From here on `keypoints` is indexed as keypoints[0] / keypoints[1] (see the early-stop slicing below).
    # NOTE(review): presumably a pair of positional-encoding tensors — confirm in `_keypoint_processing`.
    keypoints = keypoint_encoding_output[0]

    # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
    # keypoints is above a certain threshold.
    do_early_stop = self.depth_confidence > 0
    # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence
    # of the keypoints is below a certain threshold.
    # NOTE(review): `keypoint_confidences` is only assigned inside the early-stop branch, so enabling pruning
    # without early stopping fails inside the loop — confirm whether that configuration must be supported.
    do_keypoint_pruning = self.width_confidence > 0

    early_stops_indices = []
    matches = []
    matching_scores = []
    final_pruned_keypoints_indices = []
    final_pruned_keypoints_iterations = []

    pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
    pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)

    for layer_index in range(self.num_layers):
        input_shape = descriptors.size()
        # `mask` is never None here (a default is built above).
        extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
        layer_output = self.transformer_layers[layer_index](
            descriptors,
            keypoints,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        descriptors, hidden_states, attention = layer_output
        if output_hidden_states:
            all_hidden_states = all_hidden_states + hidden_states
        if output_attentions:
            all_attentions = all_attentions + attention

        if do_early_stop:
            if layer_index < self.num_layers - 1:
                # Get the confidence of the keypoints for the current layer
                keypoint_confidences = self.token_confidence[layer_index](descriptors)

                # Determine which pairs of images should be early stopped based on the confidence of the
                # keypoints for the current layer.
                early_stopped_pairs = self._get_early_stopped_image_pairs(
                    keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
                )
            else:
                # Early stopping always occurs at the last layer. Pairs stopped at earlier layers have already
                # been removed, so size the flags by the pairs still in flight. Using the original batch size
                # would mis-shape the boolean masks below. Keep the flags on the same device as those pairs.
                early_stopped_pairs = torch.ones_like(num_points_per_pair, dtype=torch.bool)

            if torch.any(early_stopped_pairs):
                # If a pair of images is considered early stopped, we compute the matches for the remaining
                # keypoints and stop the forward pass through the transformer layers for this pair of images.
                early_stops = early_stopped_pairs.repeat_interleave(2)  # per-pair flag -> per-image flag
                early_stopped_image_indices = image_indices[early_stops]
                early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
                    descriptors, mask, layer_index, early_stops=early_stops
                )
                early_stops_indices.extend(list(early_stopped_image_indices))
                matches.extend(list(early_stopped_matches))
                matching_scores.extend(list(early_stopped_matching_scores))
                if do_keypoint_pruning:
                    final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
                    final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))

                # Remove image pairs that have been early stopped from the forward pass
                num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
                descriptors, keypoints_0, keypoints_1, mask, image_indices = tuple(
                    tensor[~early_stops]
                    for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
                )
                keypoints = (keypoints_0, keypoints_1)
                if do_keypoint_pruning:
                    pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
                        tensor[~early_stops]
                        for tensor in [
                            pruned_keypoints_indices,
                            pruned_keypoints_iterations,
                            keypoint_confidences,
                        ]
                    )

                # If all pairs of images are early stopped, we stop the forward pass through the transformer
                # layers for all pairs of images.
                if torch.all(early_stopped_pairs):
                    break

        if do_keypoint_pruning:
            # Prune keypoints from the input of the transformer layers for the next iterations if the confidence
            # of the keypoints is below a certain threshold.
            descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
                self._do_layer_keypoint_pruning(
                    descriptors,
                    keypoints,
                    mask,
                    pruned_keypoints_indices,
                    pruned_keypoints_iterations,
                    keypoint_confidences,
                    layer_index,
                )
            )

    if do_early_stop and do_keypoint_pruning:
        # Concatenate early stopped outputs together and perform final keypoint pruning
        final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
            self._concat_early_stopped_outputs(
                early_stops_indices,
                final_pruned_keypoints_indices,
                final_pruned_keypoints_iterations,
                matches,
                matching_scores,
            )
        )
        matches, matching_scores = self._do_final_keypoint_pruning(
            final_pruned_keypoints_indices,
            matches,
            matching_scores,
            initial_num_keypoints,
        )
    else:
        # NOTE(review): with early stopping enabled but pruning disabled, every pair has been removed from
        # `descriptors` by the end of the loop (the last layer always stops all remaining pairs). This call
        # therefore matches an empty batch and discards the `matches` / `matching_scores` collected above.
        # That configuration looks unsupported — confirm and handle.
        matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
        final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
    final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
        batch_size, 2, initial_num_keypoints
    )

    return (
        matches,
        matching_scores,
        final_pruned_keypoints_iterations,
        all_hidden_states,
        all_attentions,
    )
@can_return_tuple
@auto_docstring
def forward(
    self,
    pixel_values: torch.FloatTensor,
    labels: torch.LongTensor | None = None,
    output_attentions: bool | None = None,
    output_hidden_states: bool | None = None,
    **kwargs,
) -> tuple | LightGlueKeypointMatchingOutput:
    # Detect keypoints in both images of every pair, then match them with `_match_image_pair`.
    # Documentation is kept in comments because `@auto_docstring` builds this method's docstring.
    # `**kwargs` is accepted but not used here.

    # There is no training objective, so labels are rejected outright.
    if labels is not None:
        raise ValueError("LightGlue is not trainable, no labels should be provided.")

    # Flags not passed explicitly fall back to the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states

    if not (pixel_values.ndim == 5 and pixel_values.size(1) == 2):
        raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")

    num_pairs, _, num_channels, height, width = pixel_values.shape
    # Flatten the pairs so the keypoint detector sees one image per row.
    flat_images = pixel_values.reshape(num_pairs * 2, num_channels, height, width)
    keypoints, _, descriptors, mask = self.keypoint_detector(flat_images)[:4]

    # Regroup the detector outputs per pair and align dtype/device with the input images.
    keypoints = keypoints.reshape(num_pairs, 2, -1, 2).to(flat_images)
    descriptors = descriptors.reshape(num_pairs, 2, -1, self.keypoint_detector_descriptor_dim).to(flat_images)
    mask = mask.reshape(num_pairs, 2, -1)

    # Scale x (index 0) by the width and y (index 1) by the height to obtain absolute keypoints for
    # matching. The unscaled `keypoints` are what the output reports.
    absolute_keypoints = keypoints.clone()
    absolute_keypoints[..., 0] = absolute_keypoints[..., 0] * width
    absolute_keypoints[..., 1] = absolute_keypoints[..., 1] * height

    matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
        absolute_keypoints,
        descriptors,
        height,
        width,
        mask=mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    return LightGlueKeypointMatchingOutput(
        loss=None,
        matches=matches,
        matching_scores=matching_scores,
        keypoints=keypoints,
        prune=prune,
        mask=mask,
        hidden_states=hidden_states,
        attentions=attentions,
    )
# Public names exported by this module (controls `from <module> import *`).
__all__ = [
    "LightGluePreTrainedModel",
    "LightGlueForKeypointMatching",
    "LightGlueConfig",
    "LightGlueImageProcessor",
    "LightGlueImageProcessorPil",
]
|