modular_lightglue.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. from dataclasses import dataclass
  16. import numpy as np
  17. import torch
  18. from huggingface_hub.dataclasses import strict
  19. from torch import nn
  20. from torch.nn.utils.rnn import pad_sequence
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  24. from ...processing_utils import ImagesKwargs, Unpack
  25. from ...utils import ModelOutput, TensorType, auto_docstring, can_return_tuple, logging
  26. from ...utils.import_utils import requires
  27. from ..auto import CONFIG_MAPPING, AutoConfig
  28. from ..auto.modeling_auto import AutoModelForKeypointDetection
  29. from ..clip.modeling_clip import CLIPMLP
  30. from ..cohere.modeling_cohere import apply_rotary_pos_emb
  31. from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
  32. from ..superglue.image_processing_pil_superglue import SuperGlueImageProcessorPil
  33. from ..superglue.image_processing_superglue import SuperGlueImageProcessor
  34. from ..superpoint import SuperPointConfig
# Module-level logger obtained through the library's `logging.get_logger` helper, keyed on this module's name.
logger = logging.get_logger(__name__)
  36. @auto_docstring(checkpoint="ETH-CVG/lightglue_superpoint")
  37. @strict
  38. class LightGlueConfig(PreTrainedConfig):
  39. r"""
  40. keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
  41. The config object or dictionary of the keypoint detector.
  42. descriptor_dim (`int`, *optional*, defaults to 256):
  43. The dimension of the descriptors.
  44. depth_confidence (`float`, *optional*, defaults to 0.95):
  45. The confidence threshold used to perform early stopping
  46. width_confidence (`float`, *optional*, defaults to 0.99):
  47. The confidence threshold used to prune points
  48. filter_threshold (`float`, *optional*, defaults to 0.1):
  49. The confidence threshold used to filter matches
  50. Examples:
  51. ```python
  52. >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching
  53. >>> # Initializing a LightGlue style configuration
  54. >>> configuration = LightGlueConfig()
  55. >>> # Initializing a model from the LightGlue style configuration
  56. >>> model = LightGlueForKeypointMatching(configuration)
  57. >>> # Accessing the model configuration
  58. >>> configuration = model.config
  59. ```
  60. """
  61. model_type = "lightglue"
  62. sub_configs = {"keypoint_detector_config": AutoConfig}
  63. keypoint_detector_config: dict | SuperPointConfig | None = None
  64. descriptor_dim: int = 256
  65. num_hidden_layers: int = 9
  66. num_attention_heads: int = 4
  67. num_key_value_heads: int | None = None
  68. depth_confidence: float = 0.95
  69. width_confidence: float = 0.99
  70. filter_threshold: float = 0.1
  71. initializer_range: float = 0.02
  72. hidden_act: str = "gelu"
  73. attention_dropout: float | int = 0.0
  74. attention_bias: bool = True
  75. def __post_init__(self, **kwargs):
  76. if self.num_key_value_heads is None:
  77. self.num_key_value_heads = self.num_attention_heads
  78. # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention
  79. # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
  80. if isinstance(self.keypoint_detector_config, dict):
  81. self.keypoint_detector_config["model_type"] = self.keypoint_detector_config.get("model_type", "superpoint")
  82. self.keypoint_detector_config = CONFIG_MAPPING[self.keypoint_detector_config["model_type"]](
  83. **self.keypoint_detector_config, attn_implementation="eager"
  84. )
  85. elif self.keypoint_detector_config is None:
  86. self.keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
  87. self.intermediate_size = self.descriptor_dim * 2
  88. self.hidden_size = self.descriptor_dim
  89. super().__post_init__(**kwargs)
  90. def validate_architecture(self):
  91. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  92. if self.descriptor_dim % self.num_attention_heads != 0:
  93. raise ValueError("descriptor_dim % num_heads is different from zero")
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
    the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
    batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
    tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
    matching information.
    """
)
class LightGlueKeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image, or -1 when the keypoint has no match.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
        Pruning mask indicating which keypoints are removed and at which layer.
    mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
        Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
        information.
    hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)` returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`
    attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)` returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`
    """

    # All fields are optional so that partially-populated outputs can be built; NOTE(review): `mask` is annotated
    # as FloatTensor while the docstring describes a BoolTensor — confirm the dtype produced by the model.
    loss: torch.FloatTensor | None = None
    matches: torch.FloatTensor | None = None
    matching_scores: torch.FloatTensor | None = None
    keypoints: torch.FloatTensor | None = None
    prune: torch.IntTensor | None = None
    mask: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
class LightGlueImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
        Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
    """

    # Extra image-processing keyword declared on top of `ImagesKwargs`; `total=False` makes every key optional.
    do_grayscale: bool
class LightGlueImageProcessor(SuperGlueImageProcessor):
    # Reuses the SuperGlue image processor unchanged; the post-processing method is only re-declared so that its
    # signature refers to `LightGlueKeypointMatchingOutput`.
    def post_process_keypoint_matching(
        self,
        outputs: "LightGlueKeypointMatchingOutput",
        target_sizes: TensorType | list[tuple],
        threshold: float = 0.0,
    ) -> list[dict[str, torch.Tensor]]:
        """
        Post-process LightGlue matching outputs by delegating to
        `SuperGlueImageProcessor.post_process_keypoint_matching`.

        Args:
            outputs: raw model outputs.
            target_sizes: target image sizes, forwarded unchanged to the parent implementation.
            threshold: score threshold, forwarded unchanged to the parent implementation.

        Returns:
            The `list[dict[str, torch.Tensor]]` produced by the parent implementation.
        """
        return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
@requires(backends=("torch",))
class LightGlueImageProcessorPil(SuperGlueImageProcessorPil):
    # Counterpart of the SuperGlue PIL image processor (presumably PIL-based, per the parent class name); both the
    # class and its post-processing method are gated on the torch backend via `@requires`.
    @requires(backends=("torch",))
    def post_process_keypoint_matching(
        self,
        outputs: "LightGlueKeypointMatchingOutput",
        target_sizes: TensorType | list[tuple],
        threshold: float = 0.0,
    ) -> list[dict[str, "torch.Tensor"]]:
        """
        Post-process LightGlue matching outputs by delegating to
        `SuperGlueImageProcessorPil.post_process_keypoint_matching`.

        Args:
            outputs: raw model outputs.
            target_sizes: target image sizes, forwarded unchanged to the parent implementation.
            threshold: score threshold, forwarded unchanged to the parent implementation.

        Returns:
            The `list[dict[str, torch.Tensor]]` produced by the parent implementation.
        """
        return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
  160. class LightGluePositionalEncoder(nn.Module):
  161. def __init__(self, config: LightGlueConfig):
  162. super().__init__()
  163. self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False)
  164. def forward(
  165. self, keypoints: torch.Tensor, output_hidden_states: bool | None = False
  166. ) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
  167. projected_keypoints = self.projector(keypoints)
  168. embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
  169. cosines = torch.cos(embeddings)
  170. sines = torch.sin(embeddings)
  171. embeddings = (cosines, sines)
  172. output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,)
  173. return output
  174. class LightGlueAttention(LlamaAttention):
  175. def __init__(self, config: LightGlueConfig, layer_idx: int):
  176. super().__init__()
  177. del self.rotary_emb
  178. def forward(
  179. self,
  180. hidden_states: torch.Tensor,
  181. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  182. attention_mask: torch.Tensor | None = None,
  183. encoder_hidden_states: torch.Tensor | None = None,
  184. encoder_attention_mask: torch.Tensor | None = None,
  185. **kwargs: Unpack[FlashAttentionKwargs],
  186. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  187. input_shape = hidden_states.shape[:-1]
  188. hidden_shape = (*input_shape, -1, self.head_dim)
  189. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  190. is_cross_attention = encoder_hidden_states is not None
  191. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  192. current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  193. key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
  194. value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)
  195. if position_embeddings is not None:
  196. cos, sin = position_embeddings
  197. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  198. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  199. self.config._attn_implementation, eager_attention_forward
  200. )
  201. attn_output, attn_weights = attention_interface(
  202. self,
  203. query_states,
  204. key_states,
  205. value_states,
  206. current_attention_mask,
  207. dropout=0.0 if not self.training else self.attention_dropout,
  208. scaling=self.scaling,
  209. **kwargs,
  210. )
  211. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  212. attn_output = self.o_proj(attn_output)
  213. return attn_output, attn_weights
  214. class LightGlueMLP(CLIPMLP):
  215. def __init__(self, config: LightGlueConfig):
  216. super().__init__(config)
  217. self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
  218. self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)
  219. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  220. hidden_states = self.fc1(hidden_states)
  221. hidden_states = self.layer_norm(hidden_states)
  222. hidden_states = self.activation_fn(hidden_states)
  223. hidden_states = self.fc2(hidden_states)
  224. return hidden_states
class LightGlueTransformerLayer(nn.Module):
    """
    One LightGlue layer: a self-attention block followed by a cross-attention block. Each block concatenates its
    input with the attention output, runs an MLP, and adds the MLP output back to the descriptors (residual).

    The batch dimension stores the two images of each pair at consecutive positions; the `reshape(-1, 2, ...)`
    calls below rely on that layout.
    """

    def __init__(self, config: LightGlueConfig, layer_idx: int):
        super().__init__()
        self.self_attention = LightGlueAttention(config, layer_idx)
        self.self_mlp = LightGlueMLP(config)
        self.cross_attention = LightGlueAttention(config, layer_idx)
        self.cross_mlp = LightGlueMLP(config)

    def forward(
        self,
        descriptors: torch.Tensor,
        keypoints: torch.Tensor,
        attention_mask: torch.Tensor,
        output_hidden_states: bool | None = False,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor] | None, tuple[torch.Tensor] | None]:
        """
        Args:
            descriptors: `(batch_size, num_keypoints, descriptor_dim)`, batch_size being 2 * number of pairs.
            keypoints: positional embeddings passed to self-attention as `position_embeddings` (a `(cos, sin)`
                pair, which the attention unpacks).
            attention_mask: presumably an extended mask of shape `(batch_size, 1, 1, num_keypoints)` (it is
                reshaped to `(-1, 2, 1, 1, num_keypoints)` below) — TODO confirm; may be None.
            output_hidden_states: collect intermediate states.
            output_attentions: collect self/cross attention weights.

        Returns:
            `(descriptors, all_hidden_states, all_attentions)`; the last two are None unless requested.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (descriptors,)

        batch_size, num_keypoints, descriptor_dim = descriptors.shape

        # Self attention block
        attention_output, self_attentions = self.self_attention(
            descriptors,
            position_embeddings=keypoints,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
        output_states = self.self_mlp(intermediate_states)
        self_attention_descriptors = descriptors + output_states

        if output_hidden_states:
            self_attention_hidden_states = (intermediate_states, output_states)

        # Reshape hidden_states to group by image_pairs :
        #   (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
        # Flip dimension 1 to perform cross attention :
        #   (image0, image1) -> (image1, image0)
        # Reshape back to original shape :
        #   (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim)
        encoder_hidden_states = (
            self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
            .flip(1)
            .reshape(batch_size, num_keypoints, descriptor_dim)
        )
        # Same for mask: each image attends using the mask of its partner image.
        encoder_attention_mask = (
            attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
            if attention_mask is not None
            else None
        )

        # Cross attention block (no positional embeddings are applied here)
        cross_attention_output, cross_attentions = self.cross_attention(
            self_attention_descriptors,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
        cross_output_states = self.cross_mlp(cross_intermediate_states)
        descriptors = self_attention_descriptors + cross_output_states

        if output_hidden_states:
            cross_attention_hidden_states = (cross_intermediate_states, cross_output_states)
            # Order: input, post-self-attention descriptors, self-block (concat, mlp out),
            # post-cross-attention descriptors, cross-block (concat, mlp out).
            all_hidden_states = (
                all_hidden_states
                + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
                + self_attention_hidden_states
                + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
                + cross_attention_hidden_states
            )

        if output_attentions:
            all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)

        return descriptors, all_hidden_states, all_attentions
  296. def sigmoid_log_double_softmax(
  297. similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
  298. ) -> torch.Tensor:
  299. """create the log assignment matrix from logits and similarity"""
  300. batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
  301. certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
  302. scores0 = nn.functional.log_softmax(similarity, 2)
  303. scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
  304. scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
  305. scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
  306. scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
  307. scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
  308. return scores
  309. class LightGlueMatchAssignmentLayer(nn.Module):
  310. def __init__(self, config: LightGlueConfig):
  311. super().__init__()
  312. self.descriptor_dim = config.descriptor_dim
  313. self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
  314. self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)
  315. def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
  316. batch_size, num_keypoints, descriptor_dim = descriptors.shape
  317. # Final projection and similarity computation
  318. m_descriptors = self.final_projection(descriptors)
  319. m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
  320. m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
  321. m_descriptors0 = m_descriptors[:, 0]
  322. m_descriptors1 = m_descriptors[:, 1]
  323. similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
  324. if mask is not None:
  325. mask = mask.reshape(batch_size // 2, 2, num_keypoints)
  326. mask0 = mask[:, 0].unsqueeze(-1)
  327. mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
  328. mask = mask0 * mask1
  329. similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)
  330. # Compute matchability of descriptors
  331. matchability = self.matchability(descriptors)
  332. matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
  333. matchability_0 = matchability[:, 0]
  334. matchability_1 = matchability[:, 1]
  335. # Compute scores from similarity and matchability
  336. scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
  337. return scores
  338. def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
  339. """Get matchability of descriptors as a probability"""
  340. matchability = self.matchability(descriptors)
  341. matchability = nn.functional.sigmoid(matchability).squeeze(-1)
  342. return matchability
  343. class LightGlueTokenConfidenceLayer(nn.Module):
  344. def __init__(self, config: LightGlueConfig):
  345. super().__init__()
  346. self.token = nn.Linear(config.descriptor_dim, 1)
  347. def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
  348. token = self.token(descriptors.detach())
  349. token = nn.functional.sigmoid(token).squeeze(-1)
  350. return token
  351. @auto_docstring
  352. class LightGluePreTrainedModel(PreTrainedModel):
  353. """
  354. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  355. models.
  356. """
  357. config: LightGlueConfig
  358. base_model_prefix = "lightglue"
  359. main_input_name = "pixel_values"
  360. input_modalities = ("image",)
  361. supports_gradient_checkpointing = False
  362. _supports_flash_attn = True
  363. _supports_sdpa = True
  364. def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
  365. """obtain matches from a score matrix [Bx M+1 x N+1]"""
  366. batch_size, _, _ = scores.shape
  367. # For each keypoint, get the best match
  368. max0 = scores[:, :-1, :-1].max(2)
  369. max1 = scores[:, :-1, :-1].max(1)
  370. matches0 = max0.indices
  371. matches1 = max1.indices
  372. # Mutual check for matches
  373. indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
  374. indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
  375. mutual0 = indices0 == matches1.gather(1, matches0)
  376. mutual1 = indices1 == matches0.gather(1, matches1)
  377. # Get matching scores and filter based on mutual check and thresholding
  378. max0 = max0.values.exp()
  379. zero = max0.new_tensor(0)
  380. matching_scores0 = torch.where(mutual0, max0, zero)
  381. matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
  382. valid0 = mutual0 & (matching_scores0 > threshold)
  383. valid1 = mutual1 & valid0.gather(1, matches1)
  384. # Filter matches based on mutual check and thresholding of scores
  385. matches0 = torch.where(valid0, matches0, -1)
  386. matches1 = torch.where(valid1, matches1, -1)
  387. matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
  388. matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
  389. return matches, matching_scores
  390. def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
  391. """
  392. Normalize keypoints locations based on image image_shape
  393. Args:
  394. keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
  395. Keypoints locations in (x, y) format.
  396. height (`int`):
  397. Image height.
  398. width (`int`):
  399. Image width.
  400. Returns:
  401. Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
  402. """
  403. size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
  404. shift = size / 2
  405. scale = size.max(-1).values / 2
  406. keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
  407. return keypoints
  408. @auto_docstring(
  409. custom_intro="""
  410. LightGlue model taking images as inputs and outputting the matching of them.
  411. """
  412. )
  413. class LightGlueForKeypointMatching(LightGluePreTrainedModel):
  414. """
  415. LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
  416. SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
  417. It consists of :
  418. 1. Keypoint Encoder
  419. 2. A Graph Neural Network with self and cross attention layers
  420. 3. Matching Assignment layers
  421. The correspondence ids use -1 to indicate non-matching points.
  422. Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
  423. In ICCV 2023. https://huggingface.co/papers/2306.13643
  424. """
  425. def __init__(self, config: LightGlueConfig):
  426. super().__init__(config)
  427. self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
  428. self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
  429. self.descriptor_dim = config.descriptor_dim
  430. self.num_layers = config.num_hidden_layers
  431. self.filter_threshold = config.filter_threshold
  432. self.depth_confidence = config.depth_confidence
  433. self.width_confidence = config.width_confidence
  434. if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
  435. self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
  436. else:
  437. self.input_projection = nn.Identity()
  438. self.positional_encoder = LightGluePositionalEncoder(config)
  439. self.transformer_layers = nn.ModuleList(
  440. [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
  441. )
  442. self.match_assignment_layers = nn.ModuleList(
  443. [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
  444. )
  445. self.token_confidence = nn.ModuleList(
  446. [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
  447. )
  448. self.post_init()
  449. def _get_confidence_threshold(self, layer_index: int) -> float:
  450. """scaled confidence threshold for a given layer"""
  451. threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
  452. return np.clip(threshold, 0, 1)
  453. def _keypoint_processing(
  454. self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: bool | None = False
  455. ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  456. descriptors = descriptors.detach().contiguous()
  457. projected_descriptors = self.input_projection(descriptors)
  458. keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
  459. return projected_descriptors, keypoint_encoding_output
  460. def _get_early_stopped_image_pairs(
  461. self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
  462. ) -> torch.Tensor:
  463. """evaluate whether we should stop inference based on the confidence of the keypoints"""
  464. batch_size, _ = mask.shape
  465. if layer_index < self.num_layers - 1:
  466. # If the current layer is not the last layer, we compute the confidence of the keypoints and check
  467. # if we should stop the forward pass through the transformer layers for each pair of images.
  468. keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
  469. keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
  470. threshold = self._get_confidence_threshold(layer_index)
  471. ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
  472. early_stopped_pairs = ratio_confident > self.depth_confidence
  473. else:
  474. # If the current layer is the last layer, we stop the forward pass through the transformer layers for
  475. # all pairs of images.
  476. early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
  477. return early_stopped_pairs
  478. def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
  479. if early_stops is not None:
  480. descriptors = descriptors[early_stops]
  481. mask = mask[early_stops]
  482. scores = self.match_assignment_layers[layer_index](descriptors, mask)
  483. matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
  484. return matches, matching_scores
  485. def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
  486. """mask points which should be removed"""
  487. keep = scores > (1 - self.width_confidence)
  488. if confidences is not None: # Low-confidence points are never pruned.
  489. keep |= confidences <= self._get_confidence_threshold(layer_index)
  490. return keep
    def _do_layer_keypoint_pruning(
        self,
        descriptors: torch.Tensor,
        keypoints: torch.Tensor,
        mask: torch.Tensor,
        indices: torch.Tensor,
        prune_output: torch.Tensor,
        keypoint_confidences: torch.Tensor,
        layer_index: int,
    ):
        """
        For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
        descriptors.

        Args:
            descriptors: `(batch_size, num_keypoints, descriptor_dim)` current descriptors.
            keypoints: indexed as `keypoints[0]` / `keypoints[1]`, each presumably `(batch_size, num_keypoints, ...)`
                (e.g. the `(cos, sin)` positional embeddings) — TODO confirm against the caller.
            mask: `(batch_size, num_keypoints)` validity mask; entries equal to 0 are never kept.
            indices: `(batch_size, num_keypoints)` original keypoint index of each current position.
            prune_output: per-image counter, incremented IN PLACE at the original index of every kept keypoint.
            keypoint_confidences: per-keypoint confidences, or None (then only matchability decides).
            layer_index: layer whose match-assignment head provides the matchability.

        Returns:
            `(pruned_descriptors, (pruned_keypoints_0, pruned_keypoints_1), pruned_indices, pruned_mask,
            prune_output)`. Per-image results are right-padded to the longest kept sequence with
            `pad_sequence` (zeros/False, and -1 for the indices).
        """
        batch_size, _, _ = descriptors.shape
        # Matchability as a probability (sigmoid of the matchability head).
        descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
        # NOTE: despite its name, this is a KEEP mask (True = keypoint survives pruning).
        pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
        pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))

        # For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
        # (`mask` inside the comprehension is a per-image keep mask that shadows the argument.)
        pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
            [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
            for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
        )
        # Count, per original keypoint index, how many layers it has survived.
        for i in range(batch_size):
            prune_output[i, pruned_indices[i]] += 1

        # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
        pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
            pad_sequence(pruned_tensor, batch_first=True)
            for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
        )
        pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
        pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)

        return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
  524. def _concat_early_stopped_outputs(
  525. self,
  526. early_stops_indices,
  527. final_pruned_keypoints_indices,
  528. final_pruned_keypoints_iterations,
  529. matches,
  530. matching_scores,
  531. ):
  532. early_stops_indices = torch.stack(early_stops_indices)
  533. # Rearrange tensors to have the same order as the input batch
  534. ids = torch.arange(early_stops_indices.shape[0])
  535. order_indices = early_stops_indices[ids]
  536. early_stops_indices = early_stops_indices[order_indices]
  537. matches, final_pruned_keypoints_indices = (
  538. pad_sequence(tensor, batch_first=True, padding_value=-1)
  539. for tensor in [matches, final_pruned_keypoints_indices]
  540. )
  541. matching_scores, final_pruned_keypoints_iterations = (
  542. pad_sequence(tensor, batch_first=True, padding_value=0)
  543. for tensor in [matching_scores, final_pruned_keypoints_iterations]
  544. )
  545. matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
  546. tensor[early_stops_indices]
  547. for tensor in [
  548. matches,
  549. matching_scores,
  550. final_pruned_keypoints_indices,
  551. final_pruned_keypoints_iterations,
  552. ]
  553. )
  554. return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
  555. def _do_final_keypoint_pruning(
  556. self,
  557. indices: torch.Tensor,
  558. matches: torch.Tensor,
  559. matching_scores: torch.Tensor,
  560. num_keypoints: torch.Tensor,
  561. ) -> tuple[torch.Tensor, torch.Tensor]:
  562. # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
  563. # have tensors from
  564. batch_size, _ = indices.shape
  565. indices, matches, matching_scores = (
  566. tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
  567. )
  568. indices0 = indices[:, 0]
  569. indices1 = indices[:, 1]
  570. matches0 = matches[:, 0]
  571. matches1 = matches[:, 1]
  572. matching_scores0 = matching_scores[:, 0]
  573. matching_scores1 = matching_scores[:, 1]
  574. # Prepare final matches and matching scores
  575. _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
  576. _matching_scores = torch.zeros(
  577. (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
  578. )
  579. # Fill the matches and matching scores for each image pair
  580. for i in range(batch_size // 2):
  581. _matches[i, 0, indices0[i]] = torch.where(
  582. matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
  583. )
  584. _matches[i, 1, indices1[i]] = torch.where(
  585. matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
  586. )
  587. _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
  588. _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
  589. return _matches, _matching_scores
  590. def _match_image_pair(
  591. self,
  592. keypoints: torch.Tensor,
  593. descriptors: torch.Tensor,
  594. height: int,
  595. width: int,
  596. mask: torch.Tensor | None = None,
  597. output_attentions: bool | None = None,
  598. output_hidden_states: bool | None = None,
  599. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
  600. all_hidden_states = () if output_hidden_states else None
  601. all_attentions = () if output_attentions else None
  602. if keypoints.shape[2] == 0: # no keypoints
  603. shape = keypoints.shape[:-1]
  604. return (
  605. keypoints.new_full(shape, -1, dtype=torch.int),
  606. keypoints.new_zeros(shape),
  607. keypoints.new_zeros(shape),
  608. all_hidden_states,
  609. all_attentions,
  610. )
  611. device = keypoints.device
  612. batch_size, _, initial_num_keypoints, _ = keypoints.shape
  613. num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)
  614. # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
  615. keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
  616. mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
  617. descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
  618. image_indices = torch.arange(batch_size * 2, device=device)
  619. # Keypoint normalization
  620. keypoints = normalize_keypoints(keypoints, height, width)
  621. descriptors, keypoint_encoding_output = self._keypoint_processing(
  622. descriptors, keypoints, output_hidden_states=output_hidden_states
  623. )
  624. keypoints = keypoint_encoding_output[0]
  625. # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
  626. # keypoints is above a certain threshold.
  627. do_early_stop = self.depth_confidence > 0
  628. # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of
  629. # the keypoints is below a certain threshold.
  630. do_keypoint_pruning = self.width_confidence > 0
  631. early_stops_indices = []
  632. matches = []
  633. matching_scores = []
  634. final_pruned_keypoints_indices = []
  635. final_pruned_keypoints_iterations = []
  636. pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
  637. pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
  638. for layer_index in range(self.num_layers):
  639. input_shape = descriptors.size()
  640. if mask is not None:
  641. extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
  642. else:
  643. extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device)
  644. layer_output = self.transformer_layers[layer_index](
  645. descriptors,
  646. keypoints,
  647. attention_mask=extended_attention_mask,
  648. output_hidden_states=output_hidden_states,
  649. output_attentions=output_attentions,
  650. )
  651. descriptors, hidden_states, attention = layer_output
  652. if output_hidden_states:
  653. all_hidden_states = all_hidden_states + hidden_states
  654. if output_attentions:
  655. all_attentions = all_attentions + attention
  656. if do_early_stop:
  657. if layer_index < self.num_layers - 1:
  658. # Get the confidence of the keypoints for the current layer
  659. keypoint_confidences = self.token_confidence[layer_index](descriptors)
  660. # Determine which pairs of images should be early stopped based on the confidence of the keypoints for
  661. # the current layer.
  662. early_stopped_pairs = self._get_early_stopped_image_pairs(
  663. keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
  664. )
  665. else:
  666. # Early stopping always occurs at the last layer
  667. early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
  668. if torch.any(early_stopped_pairs):
  669. # If a pair of images is considered early stopped, we compute the matches for the remaining
  670. # keypoints and stop the forward pass through the transformer layers for this pair of images.
  671. early_stops = early_stopped_pairs.repeat_interleave(2)
  672. early_stopped_image_indices = image_indices[early_stops]
  673. early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
  674. descriptors, mask, layer_index, early_stops=early_stops
  675. )
  676. early_stops_indices.extend(list(early_stopped_image_indices))
  677. matches.extend(list(early_stopped_matches))
  678. matching_scores.extend(list(early_stopped_matching_scores))
  679. if do_keypoint_pruning:
  680. final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
  681. final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))
  682. # Remove image pairs that have been early stopped from the forward pass
  683. num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
  684. descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple(
  685. tensor[~early_stops]
  686. for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
  687. )
  688. keypoints = (keypoints_0, keypoint_1)
  689. if do_keypoint_pruning:
  690. pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
  691. tensor[~early_stops]
  692. for tensor in [
  693. pruned_keypoints_indices,
  694. pruned_keypoints_iterations,
  695. keypoint_confidences,
  696. ]
  697. )
  698. # If all pairs of images are early stopped, we stop the forward pass through the transformer
  699. # layers for all pairs of images.
  700. if torch.all(early_stopped_pairs):
  701. break
  702. if do_keypoint_pruning:
  703. # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of
  704. # the keypoints is below a certain threshold.
  705. descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
  706. self._do_layer_keypoint_pruning(
  707. descriptors,
  708. keypoints,
  709. mask,
  710. pruned_keypoints_indices,
  711. pruned_keypoints_iterations,
  712. keypoint_confidences,
  713. layer_index,
  714. )
  715. )
  716. if do_early_stop and do_keypoint_pruning:
  717. # Concatenate early stopped outputs together and perform final keypoint pruning
  718. final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
  719. self._concat_early_stopped_outputs(
  720. early_stops_indices,
  721. final_pruned_keypoints_indices,
  722. final_pruned_keypoints_iterations,
  723. matches,
  724. matching_scores,
  725. )
  726. )
  727. matches, matching_scores = self._do_final_keypoint_pruning(
  728. final_pruned_keypoints_indices,
  729. matches,
  730. matching_scores,
  731. initial_num_keypoints,
  732. )
  733. else:
  734. matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
  735. final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
  736. final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
  737. batch_size, 2, initial_num_keypoints
  738. )
  739. return (
  740. matches,
  741. matching_scores,
  742. final_pruned_keypoints_iterations,
  743. all_hidden_states,
  744. all_attentions,
  745. )
  746. @can_return_tuple
  747. @auto_docstring
  748. def forward(
  749. self,
  750. pixel_values: torch.FloatTensor,
  751. labels: torch.LongTensor | None = None,
  752. output_attentions: bool | None = None,
  753. output_hidden_states: bool | None = None,
  754. **kwargs,
  755. ) -> tuple | LightGlueKeypointMatchingOutput:
  756. loss = None
  757. if labels is not None:
  758. raise ValueError("LightGlue is not trainable, no labels should be provided.")
  759. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  760. output_hidden_states = (
  761. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  762. )
  763. if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
  764. raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
  765. batch_size, _, channels, height, width = pixel_values.shape
  766. pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
  767. keypoint_detections = self.keypoint_detector(pixel_values)
  768. keypoints, _, descriptors, mask = keypoint_detections[:4]
  769. keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
  770. descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
  771. mask = mask.reshape(batch_size, 2, -1)
  772. absolute_keypoints = keypoints.clone()
  773. absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
  774. absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
  775. matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
  776. absolute_keypoints,
  777. descriptors,
  778. height,
  779. width,
  780. mask=mask,
  781. output_attentions=output_attentions,
  782. output_hidden_states=output_hidden_states,
  783. )
  784. return LightGlueKeypointMatchingOutput(
  785. loss=loss,
  786. matches=matches,
  787. matching_scores=matching_scores,
  788. keypoints=keypoints,
  789. prune=prune,
  790. mask=mask,
  791. hidden_states=hidden_states,
  792. attentions=attentions,
  793. )
  794. __all__ = [
  795. "LightGluePreTrainedModel",
  796. "LightGlueForKeypointMatching",
  797. "LightGlueConfig",
  798. "LightGlueImageProcessor",
  799. "LightGlueImageProcessorPil",
  800. ]