modeling_efficientloftr.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. from dataclasses import dataclass
  16. from typing import Optional
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2CLS, ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import BackboneOutput
  23. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...pytorch_utils import compile_compatible_method_lru_cache
  27. from ...utils import (
  28. ModelOutput,
  29. TransformersKwargs,
  30. auto_docstring,
  31. can_return_tuple,
  32. torch_int,
  33. )
  34. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_efficientloftr import EfficientLoFTRConfig
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of EfficientLoFTR keypoint matching models. Due to the nature of keypoint detection and matching, the number
    of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
    images, the maximum number of matches is set as the dimension of the matches and matching scores.
    """
)
class EfficientLoFTRKeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`.
    attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
    """

    # Every field defaults to `None`, so an output can be built with only the entries that were computed.
    loss: torch.FloatTensor | None = None
    matches: torch.FloatTensor | None = None
    matching_scores: torch.FloatTensor | None = None
    keypoints: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  69. @compile_compatible_method_lru_cache(maxsize=32)
  70. def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
  71. i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
  72. j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
  73. i_indices = i_indices.cumsum(0).unsqueeze(-1)
  74. j_indices = j_indices.cumsum(1).unsqueeze(-1)
  75. emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype, device=inv_freq.device)
  76. emb[:, :, :, 0::2] = i_indices * inv_freq
  77. emb[:, :, :, 1::2] = j_indices * inv_freq
  78. return emb
  79. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->EfficientLoFTR
  80. class EfficientLoFTRRotaryEmbedding(nn.Module):
  81. inv_freq: torch.Tensor # fix linting for `register_buffer`
  82. # Ignore copy
  83. def __init__(self, config: EfficientLoFTRConfig, device=None):
  84. super().__init__()
  85. self.config = config
  86. self.rope_type = self.config.rope_parameters["rope_type"]
  87. rope_init_fn: Callable = self.compute_default_rope_parameters
  88. if self.rope_type != "default":
  89. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  90. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  91. self.register_buffer("inv_freq", inv_freq, persistent=False)
  92. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  93. @staticmethod
  94. # Ignore copy
  95. def compute_default_rope_parameters(
  96. config: EfficientLoFTRConfig | None = None,
  97. device: Optional["torch.device"] = None,
  98. seq_len: int | None = None,
  99. ) -> tuple["torch.Tensor", float]:
  100. """
  101. Computes the inverse frequencies according to the original RoPE implementation
  102. Args:
  103. config ([`~transformers.PreTrainedConfig`]):
  104. The model configuration.
  105. device (`torch.device`):
  106. The device to use for initialization of the inverse frequencies.
  107. seq_len (`int`, *optional*):
  108. The current sequence length. Unused for this type of RoPE.
  109. Returns:
  110. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  111. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  112. """
  113. base = config.rope_parameters["rope_theta"]
  114. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  115. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  116. dim = int(head_dim * partial_rotary_factor)
  117. attention_factor = 1.0 # Unused in this type of RoPE
  118. # Compute the inverse frequencies
  119. inv_freq = 1.0 / (
  120. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  121. )
  122. return inv_freq, attention_factor
  123. # Ignore copy
  124. @torch.no_grad()
  125. def forward(
  126. self, x: torch.Tensor, position_ids: torch.LongTensor | None = None, layer_type=None
  127. ) -> tuple[torch.Tensor, torch.Tensor]:
  128. feats_height, feats_width = x.shape[-2:]
  129. embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
  130. embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
  131. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  132. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  133. emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
  134. sin = emb.sin()
  135. cos = emb.cos()
  136. sin = sin.repeat_interleave(2, dim=-1)
  137. cos = cos.repeat_interleave(2, dim=-1)
  138. sin = sin.to(device=x.device, dtype=x.dtype)
  139. cos = cos.to(device=x.device, dtype=x.dtype)
  140. return cos, sin
  141. # Copied from transformers.models.rt_detr_v2.modeling_rt_detr_v2.RTDetrV2ConvNormLayer with RTDetrV2->EfficientLoFTR
  142. class EfficientLoFTRConvNormLayer(nn.Module):
  143. def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
  144. super().__init__()
  145. self.conv = nn.Conv2d(
  146. in_channels,
  147. out_channels,
  148. kernel_size,
  149. stride,
  150. padding=(kernel_size - 1) // 2 if padding is None else padding,
  151. bias=False,
  152. )
  153. self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
  154. self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
  155. def forward(self, hidden_state):
  156. hidden_state = self.conv(hidden_state)
  157. hidden_state = self.norm(hidden_state)
  158. hidden_state = self.activation(hidden_state)
  159. return hidden_state
  160. class EfficientLoFTRRepVGGBlock(GradientCheckpointingLayer):
  161. """
  162. RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
  163. """
  164. def __init__(self, config: EfficientLoFTRConfig, stage_idx: int, block_idx: int):
  165. super().__init__()
  166. in_channels = config.stage_block_in_channels[stage_idx][block_idx]
  167. out_channels = config.stage_block_out_channels[stage_idx][block_idx]
  168. stride = config.stage_block_stride[stage_idx][block_idx]
  169. activation = config.activation_function
  170. self.conv1 = EfficientLoFTRConvNormLayer(
  171. config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1
  172. )
  173. self.conv2 = EfficientLoFTRConvNormLayer(
  174. config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0
  175. )
  176. self.identity = nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None
  177. self.activation = nn.Identity() if activation is None else ACT2FN[activation]
  178. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  179. if self.identity is not None:
  180. identity_out = self.identity(hidden_states)
  181. else:
  182. identity_out = 0
  183. hidden_states = self.conv1(hidden_states) + self.conv2(hidden_states) + identity_out
  184. hidden_states = self.activation(hidden_states)
  185. return hidden_states
  186. class EfficientLoFTRRepVGGStage(nn.Module):
  187. def __init__(self, config: EfficientLoFTRConfig, stage_idx: int):
  188. super().__init__()
  189. self.blocks = nn.ModuleList([])
  190. for block_idx in range(config.stage_num_blocks[stage_idx]):
  191. self.blocks.append(
  192. EfficientLoFTRRepVGGBlock(
  193. config,
  194. stage_idx,
  195. block_idx,
  196. )
  197. )
  198. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  199. for block in self.blocks:
  200. hidden_states = block(hidden_states)
  201. return hidden_states
  202. class EfficientLoFTRepVGG(nn.Module):
  203. def __init__(self, config: EfficientLoFTRConfig):
  204. super().__init__()
  205. self.stages = nn.ModuleList([])
  206. for stage_idx in range(len(config.stage_stride)):
  207. stage = EfficientLoFTRRepVGGStage(config, stage_idx)
  208. self.stages.append(stage)
  209. def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
  210. outputs = []
  211. for stage in self.stages:
  212. hidden_states = stage(hidden_states)
  213. outputs.append(hidden_states)
  214. # Exclude first stage in outputs
  215. outputs = outputs[1:]
  216. return outputs
  217. class EfficientLoFTRAggregationLayer(nn.Module):
  218. def __init__(self, config: EfficientLoFTRConfig):
  219. super().__init__()
  220. hidden_size = config.hidden_size
  221. self.q_aggregation = nn.Conv2d(
  222. hidden_size,
  223. hidden_size,
  224. kernel_size=config.q_aggregation_kernel_size,
  225. padding=0,
  226. stride=config.q_aggregation_stride,
  227. bias=False,
  228. groups=hidden_size,
  229. )
  230. self.kv_aggregation = torch.nn.MaxPool2d(
  231. kernel_size=config.kv_aggregation_kernel_size, stride=config.kv_aggregation_stride
  232. )
  233. self.norm = nn.LayerNorm(hidden_size)
  234. def forward(
  235. self,
  236. hidden_states: torch.Tensor,
  237. encoder_hidden_states: torch.Tensor | None = None,
  238. ) -> tuple[torch.Tensor, torch.Tensor]:
  239. query_states = hidden_states
  240. is_cross_attention = encoder_hidden_states is not None
  241. kv_states = encoder_hidden_states if is_cross_attention else hidden_states
  242. query_states = self.q_aggregation(query_states)
  243. kv_states = self.kv_aggregation(kv_states)
  244. query_states = query_states.permute(0, 2, 3, 1)
  245. kv_states = kv_states.permute(0, 2, 3, 1)
  246. hidden_states = self.norm(query_states)
  247. encoder_hidden_states = self.norm(kv_states)
  248. return hidden_states, encoder_hidden_states
  249. # Copied from transformers.models.cohere.modeling_cohere.rotate_half
  250. def rotate_half(x):
  251. # Split and rotate. Note that this function is different from e.g. Llama.
  252. x1 = x[..., ::2]
  253. x2 = x[..., 1::2]
  254. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  255. return rot_x
  256. # Copied from transformers.models.cohere.modeling_cohere.apply_rotary_pos_emb
  257. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  258. """Applies Rotary Position Embedding to the query and key tensors.
  259. Args:
  260. q (`torch.Tensor`): The query tensor.
  261. k (`torch.Tensor`): The key tensor.
  262. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  263. sin (`torch.Tensor`): The sine part of the rotary embedding.
  264. unsqueeze_dim (`int`, *optional*, defaults to 1):
  265. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  266. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  267. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  268. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  269. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  270. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  271. Returns:
  272. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  273. """
  274. dtype = q.dtype
  275. q = q.float()
  276. k = k.float()
  277. cos = cos.unsqueeze(unsqueeze_dim)
  278. sin = sin.unsqueeze(unsqueeze_dim)
  279. q_embed = (q * cos) + (rotate_half(q) * sin)
  280. k_embed = (k * cos) + (rotate_half(k) * sin)
  281. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  282. # Copied from transformers.models.cohere.modeling_cohere.repeat_kv
  283. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  284. """
  285. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  286. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  287. """
  288. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  289. if n_rep == 1:
  290. return hidden_states
  291. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  292. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  293. # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
  294. def eager_attention_forward(
  295. module: nn.Module,
  296. query: torch.Tensor,
  297. key: torch.Tensor,
  298. value: torch.Tensor,
  299. attention_mask: torch.Tensor | None,
  300. scaling: float,
  301. dropout: float = 0.0,
  302. **kwargs: Unpack[TransformersKwargs],
  303. ):
  304. key_states = repeat_kv(key, module.num_key_value_groups)
  305. value_states = repeat_kv(value, module.num_key_value_groups)
  306. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  307. if attention_mask is not None:
  308. attn_weights = attn_weights + attention_mask
  309. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  310. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  311. attn_output = torch.matmul(attn_weights, value_states)
  312. attn_output = attn_output.transpose(1, 2).contiguous()
  313. return attn_output, attn_weights
  314. class EfficientLoFTRAttention(nn.Module):
  315. """Multi-headed attention from 'Attention Is All You Need' paper"""
  316. def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
  317. super().__init__()
  318. self.config = config
  319. self.layer_idx = layer_idx
  320. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  321. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  322. self.scaling = self.head_dim**-0.5
  323. self.attention_dropout = config.attention_dropout
  324. self.is_causal = False
  325. self.q_proj = nn.Linear(
  326. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  327. )
  328. self.k_proj = nn.Linear(
  329. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  330. )
  331. self.v_proj = nn.Linear(
  332. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  333. )
  334. self.o_proj = nn.Linear(
  335. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  336. )
  337. def forward(
  338. self,
  339. hidden_states: torch.Tensor,
  340. encoder_hidden_states: torch.Tensor | None = None,
  341. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  342. **kwargs: Unpack[TransformersKwargs],
  343. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  344. batch_size, seq_len, dim = hidden_states.shape
  345. input_shape = hidden_states.shape[:-1]
  346. query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim)
  347. current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
  348. key_states = self.k_proj(current_states).view(batch_size, seq_len, -1, dim)
  349. value_states = self.v_proj(current_states).view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
  350. if position_embeddings is not None:
  351. cos, sin = position_embeddings
  352. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)
  353. query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
  354. key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
  355. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  356. self.config._attn_implementation, eager_attention_forward
  357. )
  358. attn_output, attn_weights = attention_interface(
  359. self,
  360. query_states,
  361. key_states,
  362. value_states,
  363. attention_mask=None,
  364. dropout=0.0 if not self.training else self.attention_dropout,
  365. scaling=self.scaling,
  366. **kwargs,
  367. )
  368. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  369. attn_output = self.o_proj(attn_output)
  370. return attn_output, attn_weights
  371. class EfficientLoFTRMLP(nn.Module):
  372. def __init__(self, config: EfficientLoFTRConfig):
  373. super().__init__()
  374. hidden_size = config.hidden_size
  375. intermediate_size = config.intermediate_size
  376. self.fc1 = nn.Linear(hidden_size * 2, intermediate_size, bias=False)
  377. self.activation = ACT2FN[config.mlp_activation_function]
  378. self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)
  379. self.layer_norm = nn.LayerNorm(hidden_size)
  380. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  381. hidden_states = self.fc1(hidden_states)
  382. hidden_states = self.activation(hidden_states)
  383. hidden_states = self.fc2(hidden_states)
  384. hidden_states = self.layer_norm(hidden_states)
  385. return hidden_states
  386. class EfficientLoFTRAggregatedAttention(nn.Module):
  387. def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
  388. super().__init__()
  389. self.q_aggregation_kernel_size = config.q_aggregation_kernel_size
  390. self.aggregation = EfficientLoFTRAggregationLayer(config)
  391. self.attention = EfficientLoFTRAttention(config, layer_idx)
  392. self.mlp = EfficientLoFTRMLP(config)
  393. def forward(
  394. self,
  395. hidden_states: torch.Tensor,
  396. encoder_hidden_states: torch.Tensor | None = None,
  397. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  398. **kwargs: Unpack[TransformersKwargs],
  399. ) -> torch.Tensor:
  400. batch_size, embed_dim, _, _ = hidden_states.shape
  401. # Aggregate features
  402. aggregated_hidden_states, aggregated_encoder_hidden_states = self.aggregation(
  403. hidden_states, encoder_hidden_states
  404. )
  405. _, aggregated_h, aggregated_w, _ = aggregated_hidden_states.shape
  406. # Multi-head attention
  407. aggregated_hidden_states = aggregated_hidden_states.reshape(batch_size, -1, embed_dim)
  408. aggregated_encoder_hidden_states = aggregated_encoder_hidden_states.reshape(batch_size, -1, embed_dim)
  409. attn_output, _ = self.attention(
  410. aggregated_hidden_states,
  411. aggregated_encoder_hidden_states,
  412. position_embeddings=position_embeddings,
  413. **kwargs,
  414. )
  415. # Upsample features
  416. # (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim, h, w) with seq_len = h * w
  417. attn_output = attn_output.permute(0, 2, 1)
  418. attn_output = attn_output.reshape(batch_size, embed_dim, aggregated_h, aggregated_w)
  419. attn_output = torch.nn.functional.interpolate(
  420. attn_output, scale_factor=self.q_aggregation_kernel_size, mode="bilinear", align_corners=False
  421. )
  422. intermediate_states = torch.cat([hidden_states, attn_output], dim=1)
  423. intermediate_states = intermediate_states.permute(0, 2, 3, 1)
  424. output_states = self.mlp(intermediate_states)
  425. output_states = output_states.permute(0, 3, 1, 2)
  426. hidden_states = hidden_states + output_states
  427. return hidden_states
  428. class EfficientLoFTRLocalFeatureTransformerLayer(GradientCheckpointingLayer):
  429. def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
  430. super().__init__()
  431. self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)
  432. self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)
  433. def forward(
  434. self,
  435. hidden_states: torch.Tensor,
  436. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  437. **kwargs: Unpack[TransformersKwargs],
  438. ) -> torch.Tensor:
  439. batch_size, _, embed_dim, height, width = hidden_states.shape
  440. hidden_states = hidden_states.reshape(-1, embed_dim, height, width)
  441. hidden_states = self.self_attention(hidden_states, position_embeddings=position_embeddings, **kwargs)
  442. ###
  443. # Implementation of a bug in the original implementation regarding the cross-attention
  444. # See : https://github.com/zju3dv/MatchAnything/issues/26
  445. hidden_states = hidden_states.reshape(-1, 2, embed_dim, height, width)
  446. features_0 = hidden_states[:, 0]
  447. features_1 = hidden_states[:, 1]
  448. features_0 = self.cross_attention(features_0, features_1, **kwargs)
  449. features_1 = self.cross_attention(features_1, features_0, **kwargs)
  450. hidden_states = torch.stack((features_0, features_1), dim=1)
  451. ###
  452. return hidden_states
  453. class EfficientLoFTRLocalFeatureTransformer(nn.Module):
  454. def __init__(self, config: EfficientLoFTRConfig):
  455. super().__init__()
  456. self.layers = nn.ModuleList(
  457. [
  458. EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=i)
  459. for i in range(config.num_attention_layers)
  460. ]
  461. )
  462. def forward(
  463. self,
  464. hidden_states: torch.Tensor,
  465. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  466. **kwargs: Unpack[TransformersKwargs],
  467. ) -> torch.Tensor:
  468. for layer in self.layers:
  469. hidden_states = layer(hidden_states, position_embeddings=position_embeddings, **kwargs)
  470. return hidden_states
  471. class EfficientLoFTROutConvBlock(nn.Module):
  472. def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int):
  473. super().__init__()
  474. self.out_conv1 = nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False)
  475. self.out_conv2 = nn.Conv2d(
  476. intermediate_size, intermediate_size, kernel_size=3, stride=1, padding=1, bias=False
  477. )
  478. self.batch_norm = nn.BatchNorm2d(intermediate_size)
  479. self.activation = ACT2CLS[config.mlp_activation_function]()
  480. self.out_conv3 = nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False)
  481. def forward(self, hidden_states: torch.Tensor, residual_states: torch.Tensor) -> torch.Tensor:
  482. residual_states = self.out_conv1(residual_states)
  483. residual_states = residual_states + hidden_states
  484. residual_states = self.out_conv2(residual_states)
  485. residual_states = self.batch_norm(residual_states)
  486. residual_states = self.activation(residual_states)
  487. residual_states = self.out_conv3(residual_states)
  488. residual_states = nn.functional.interpolate(
  489. residual_states, scale_factor=2.0, mode="bilinear", align_corners=False
  490. )
  491. return residual_states
  492. class EfficientLoFTRFineFusionLayer(nn.Module):
  493. def __init__(self, config: EfficientLoFTRConfig):
  494. super().__init__()
  495. self.fine_kernel_size = config.fine_kernel_size
  496. fine_fusion_dims = config.fine_fusion_dims
  497. self.out_conv = nn.Conv2d(
  498. fine_fusion_dims[0], fine_fusion_dims[0], kernel_size=1, stride=1, padding=0, bias=False
  499. )
  500. self.out_conv_layers = nn.ModuleList()
  501. for i in range(1, len(fine_fusion_dims)):
  502. out_conv = EfficientLoFTROutConvBlock(config, fine_fusion_dims[i], fine_fusion_dims[i - 1])
  503. self.out_conv_layers.append(out_conv)
  504. def forward_pyramid(
  505. self,
  506. hidden_states: torch.Tensor,
  507. residual_states: list[torch.Tensor],
  508. ) -> torch.Tensor:
  509. hidden_states = self.out_conv(hidden_states)
  510. hidden_states = nn.functional.interpolate(
  511. hidden_states, scale_factor=2.0, mode="bilinear", align_corners=False
  512. )
  513. for i, layer in enumerate(self.out_conv_layers):
  514. hidden_states = layer(hidden_states, residual_states[i])
  515. return hidden_states
  516. def forward(
  517. self,
  518. coarse_features: torch.Tensor,
  519. residual_features: list[torch.Tensor] | tuple[torch.Tensor],
  520. ) -> tuple[torch.Tensor, torch.Tensor]:
  521. """
  522. For each image pair, compute the fine features of pixels.
  523. In both images, compute a patch of fine features center cropped around each coarse pixel.
  524. In the first image, the feature patch is kernel_size large and long.
  525. In the second image, it is (kernel_size + 2) large and long.
  526. """
  527. batch_size, _, embed_dim, coarse_height, coarse_width = coarse_features.shape
  528. coarse_features = coarse_features.reshape(-1, embed_dim, coarse_height, coarse_width)
  529. residual_features = list(reversed(residual_features))
  530. # 1. Fine feature extraction
  531. fine_features = self.forward_pyramid(coarse_features, residual_features)
  532. _, fine_embed_dim, fine_height, fine_width = fine_features.shape
  533. fine_features = fine_features.reshape(batch_size, 2, fine_embed_dim, fine_height, fine_width)
  534. fine_features_0 = fine_features[:, 0]
  535. fine_features_1 = fine_features[:, 1]
  536. # 2. Unfold all local windows in crops
  537. stride = int(fine_height // coarse_height)
  538. fine_features_0 = nn.functional.unfold(
  539. fine_features_0, kernel_size=self.fine_kernel_size, stride=stride, padding=0
  540. )
  541. _, _, seq_len = fine_features_0.shape
  542. fine_features_0 = fine_features_0.reshape(batch_size, -1, self.fine_kernel_size**2, seq_len)
  543. fine_features_0 = fine_features_0.permute(0, 3, 2, 1)
  544. fine_features_1 = nn.functional.unfold(
  545. fine_features_1, kernel_size=self.fine_kernel_size + 2, stride=stride, padding=1
  546. )
  547. fine_features_1 = fine_features_1.reshape(batch_size, -1, (self.fine_kernel_size + 2) ** 2, seq_len)
  548. fine_features_1 = fine_features_1.permute(0, 3, 2, 1)
  549. return fine_features_0, fine_features_1
  550. @auto_docstring
  551. class EfficientLoFTRPreTrainedModel(PreTrainedModel):
  552. """
  553. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  554. models.
  555. """
  556. config_class = EfficientLoFTRConfig
  557. base_model_prefix = "efficientloftr"
  558. main_input_name = "pixel_values"
  559. input_modalities = ("image",)
  560. supports_gradient_checkpointing = True
  561. _supports_flash_attn = True
  562. _supports_sdpa = True
  563. _can_record_outputs = {
  564. "hidden_states": EfficientLoFTRRepVGGBlock,
  565. "attentions": EfficientLoFTRAttention,
  566. }
  567. @torch.no_grad()
  568. def _init_weights(self, module: nn.Module) -> None:
  569. """Initialize the weights"""
  570. if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d)):
  571. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  572. if module.bias is not None:
  573. init.zeros_(module.bias)
  574. if getattr(module, "running_mean", None) is not None:
  575. init.zeros_(module.running_mean)
  576. init.ones_(module.running_var)
  577. init.zeros_(module.num_batches_tracked)
  578. elif isinstance(module, nn.LayerNorm):
  579. init.zeros_(module.bias)
  580. init.ones_(module.weight)
  581. elif isinstance(module, EfficientLoFTRRotaryEmbedding):
  582. rope_fn = (
  583. ROPE_INIT_FUNCTIONS[module.rope_type]
  584. if module.rope_type != "default"
  585. else module.compute_default_rope_parameters
  586. )
  587. buffer_value, _ = rope_fn(module.config)
  588. init.copy_(module.inv_freq, buffer_value)
  589. init.copy_(module.original_inv_freq, buffer_value)
  590. # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
  591. def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
  592. """
  593. Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same,
  594. extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for EfficientLoFTR. This is
  595. a workaround for the issue discussed in :
  596. https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
  597. Args:
  598. pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)
  599. Returns:
  600. pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
  601. """
  602. return pixel_values[:, 0, :, :][:, None, :, :]
  603. @auto_docstring(
  604. custom_intro="""
  605. EfficientLoFTR model taking images as inputs and outputting the features of the images.
  606. """
  607. )
  608. class EfficientLoFTRModel(EfficientLoFTRPreTrainedModel):
  609. def __init__(self, config: EfficientLoFTRConfig):
  610. super().__init__(config)
  611. self.config = config
  612. self.backbone = EfficientLoFTRepVGG(config)
  613. self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config)
  614. self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config)
  615. self.post_init()
  616. @merge_with_config_defaults
  617. @capture_outputs
  618. @auto_docstring
  619. def forward(
  620. self,
  621. pixel_values: torch.FloatTensor,
  622. labels: torch.LongTensor | None = None,
  623. **kwargs: Unpack[TransformersKwargs],
  624. ) -> BackboneOutput:
  625. r"""
  626. Examples:
  627. ```python
  628. >>> from transformers import AutoImageProcessor, AutoModel
  629. >>> import torch
  630. >>> from PIL import Image
  631. >>> import httpx
  632. >>> from io import BytesIO
  633. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
  634. >>> with httpx.stream("GET", url) as response:
  635. ... image1 = Image.open(BytesIO(response.read()))
  636. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
  637. >>> with httpx.stream("GET", url) as response:
  638. ... image2 = Image.open(BytesIO(response.read()))
  639. >>> images = [image1, image2]
  640. >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
  641. >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")
  642. >>> with torch.no_grad():
  643. >>> inputs = processor(images, return_tensors="pt")
  644. >>> outputs = model(**inputs)
  645. ```"""
  646. if labels is not None:
  647. raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.")
  648. if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
  649. raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
  650. batch_size, _, channels, height, width = pixel_values.shape
  651. pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
  652. pixel_values = self.extract_one_channel_pixel_values(pixel_values)
  653. # 1. Local Feature CNN
  654. features = self.backbone(pixel_values)
  655. # Last stage outputs are coarse outputs
  656. coarse_features = features[-1]
  657. # Rest is residual features used in EfficientLoFTRFineFusionLayer
  658. residual_features = features[:-1]
  659. coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]
  660. # 2. Coarse-level LoFTR module
  661. cos, sin = self.rotary_emb(coarse_features)
  662. cos = cos.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
  663. sin = sin.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
  664. position_embeddings = (cos, sin)
  665. coarse_features = coarse_features.reshape(batch_size, 2, coarse_embed_dim, coarse_height, coarse_width)
  666. coarse_features = self.local_feature_transformer(
  667. coarse_features, position_embeddings=position_embeddings, **kwargs
  668. )
  669. features = (coarse_features,) + tuple(residual_features)
  670. return BackboneOutput(feature_maps=features)
  671. def mask_border(tensor: torch.Tensor, border_margin: int, value: bool | float | int) -> torch.Tensor:
  672. """
  673. Mask a tensor border with a given value
  674. Args:
  675. tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
  676. The tensor to mask
  677. border_margin (`int`) :
  678. The size of the border
  679. value (`Union[bool, int, float]`):
  680. The value to place in the tensor's borders
  681. Returns:
  682. tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
  683. The masked tensor
  684. """
  685. if border_margin <= 0:
  686. return tensor
  687. tensor[:, :border_margin] = value
  688. tensor[:, :, :border_margin] = value
  689. tensor[:, :, :, :border_margin] = value
  690. tensor[:, :, :, :, :border_margin] = value
  691. tensor[:, -border_margin:] = value
  692. tensor[:, :, -border_margin:] = value
  693. tensor[:, :, :, -border_margin:] = value
  694. tensor[:, :, :, :, -border_margin:] = value
  695. return tensor
  696. def create_meshgrid(
  697. height: int | torch.Tensor,
  698. width: int | torch.Tensor,
  699. normalized_coordinates: bool = False,
  700. device: torch.device | None = None,
  701. dtype: torch.dtype | None = None,
  702. ) -> torch.Tensor:
  703. """
  704. Copied from kornia library : kornia/kornia/utils/grid.py:26
  705. Generate a coordinate grid for an image.
  706. When the flag ``normalized_coordinates`` is set to True, the grid is
  707. normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
  708. function :py:func:`torch.nn.functional.grid_sample`.
  709. Args:
  710. height (`int`):
  711. The image height (rows).
  712. width (`int`):
  713. The image width (cols).
  714. normalized_coordinates (`bool`):
  715. Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the
  716. PyTorch function :py:func:`torch.nn.functional.grid_sample`.
  717. device (`torch.device`):
  718. The device on which the grid will be generated.
  719. dtype (`torch.dtype`):
  720. The data type of the generated grid.
  721. Return:
  722. grid (`torch.Tensor` of shape `(1, height, width, 2)`):
  723. The grid tensor.
  724. Example:
  725. >>> create_meshgrid(2, 2)
  726. tensor([[[[-1., -1.],
  727. [ 1., -1.]],
  728. <BLANKLINE>
  729. [[-1., 1.],
  730. [ 1., 1.]]]])
  731. >>> create_meshgrid(2, 2, normalized_coordinates=False)
  732. tensor([[[[0., 0.],
  733. [1., 0.]],
  734. <BLANKLINE>
  735. [[0., 1.],
  736. [1., 1.]]]])
  737. """
  738. xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
  739. ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
  740. if normalized_coordinates:
  741. xs = (xs / (width - 1) - 0.5) * 2
  742. ys = (ys / (height - 1) - 0.5) * 2
  743. grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)
  744. grid = grid.permute(1, 0, 2).unsqueeze(0)
  745. return grid
  746. def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor:
  747. r"""
  748. Copied from kornia library : kornia/geometry/subpix/dsnt.py:76
  749. Compute the expectation of coordinate values using spatial probabilities.
  750. The input heatmap is assumed to represent a valid spatial probability distribution,
  751. which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.
  752. Args:
  753. input (`torch.Tensor` of shape `(batch_size, embed_dim, height, width)`):
  754. The input tensor representing dense spatial probabilities.
  755. normalized_coordinates (`bool`):
  756. Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return
  757. the coordinates in the range of the input shape.
  758. Returns:
  759. output (`torch.Tensor` of shape `(batch_size, embed_dim, 2)`)
  760. Expected value of the 2D coordinates. Output order of the coordinates is (x, y).
  761. Examples:
  762. >>> heatmaps = torch.tensor([[[
  763. ... [0., 0., 0.],
  764. ... [0., 0., 0.],
  765. ... [0., 1., 0.]]]])
  766. >>> spatial_expectation2d(heatmaps, False)
  767. tensor([[[1., 2.]]])
  768. """
  769. batch_size, embed_dim, height, width = input.shape
  770. # Create coordinates grid.
  771. grid = create_meshgrid(height, width, normalized_coordinates, input.device)
  772. grid = grid.to(input.dtype)
  773. pos_x = grid[..., 0].reshape(-1)
  774. pos_y = grid[..., 1].reshape(-1)
  775. input_flat = input.view(batch_size, embed_dim, -1)
  776. # Compute the expectation of the coordinates.
  777. expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
  778. expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)
  779. output = torch.cat([expected_x, expected_y], -1)
  780. return output.view(batch_size, embed_dim, 2)
  781. @auto_docstring(
  782. custom_intro="""
  783. EfficientLoFTR model taking images as inputs and outputting the matching of them.
  784. """
  785. )
  786. class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel):
  787. """EfficientLoFTR dense image matcher
  788. Given two images, we determine the correspondences by:
  789. 1. Extracting coarse and fine features through a backbone
  790. 2. Transforming coarse features through self and cross attention
  791. 3. Matching coarse features to obtain coarse coordinates of matches
  792. 4. Obtaining full resolution fine features by fusing transformed and backbone coarse features
  793. 5. Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage refinement
  794. Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.
  795. Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed
  796. In CVPR, 2024. https://huggingface.co/papers/2403.04765
  797. """
  798. def __init__(self, config: EfficientLoFTRConfig):
  799. super().__init__(config)
  800. self.config = config
  801. self.efficientloftr = EfficientLoFTRModel(config)
  802. self.refinement_layer = EfficientLoFTRFineFusionLayer(config)
  803. self.post_init()
  804. def _get_matches_from_scores(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  805. """
  806. Based on a keypoint score matrix, compute the best keypoint matches between the first and second image.
  807. Since each image pair can have different number of matches, the matches are concatenated together for all pair
  808. in the batch and a batch_indices tensor is returned to specify which match belong to which element in the batch.
  809. Note:
  810. This step can be done as a postprocessing step, because does not involve any model weights/params.
  811. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  812. easier torch.compile/torch.export (all ops are in torch).
  813. Args:
  814. scores (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
  815. Scores of keypoints
  816. Returns:
  817. matched_indices (`torch.Tensor` of shape `(2, num_matches)`):
  818. Indices representing which pixel in the first image matches which pixel in the second image
  819. matching_scores (`torch.Tensor` of shape `(num_matches,)`):
  820. Scores of each match
  821. """
  822. batch_size, height0, width0, height1, width1 = scores.shape
  823. scores = scores.view(batch_size, height0 * width0, height1 * width1)
  824. # For each keypoint, get the best match
  825. max_0 = scores.max(2, keepdim=True).values
  826. max_1 = scores.max(1, keepdim=True).values
  827. # 1. Thresholding
  828. mask = scores > self.config.coarse_matching_threshold
  829. # 2. Border removal
  830. mask = mask.reshape(batch_size, height0, width0, height1, width1)
  831. mask = mask_border(mask, self.config.coarse_matching_border_removal, False)
  832. mask = mask.reshape(batch_size, height0 * width0, height1 * width1)
  833. # 3. Mutual nearest neighbors
  834. mask = mask * (scores == max_0) * (scores == max_1)
  835. # 4. Fine coarse matches
  836. masked_scores = scores * mask
  837. matching_scores_0, max_indices_0 = masked_scores.max(1)
  838. matching_scores_1, max_indices_1 = masked_scores.max(2)
  839. matching_indices = torch.cat([max_indices_0, max_indices_1]).reshape(batch_size, 2, -1)
  840. matching_scores = torch.stack([matching_scores_0, matching_scores_1], dim=1)
  841. # For the keypoints not meeting the threshold score, set the indices to -1 which corresponds to no matches found
  842. matching_indices = torch.where(matching_scores > 0, matching_indices, -1)
  843. return matching_indices, matching_scores
  844. def _coarse_matching(
  845. self, coarse_features: torch.Tensor, coarse_scale: float
  846. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  847. """
  848. For each image pair, compute the matching confidence between each coarse element (by default (image_height / 8)
  849. * (image_width / 8 elements)) from the first image to the second image.
  850. Note:
  851. This step can be done as a postprocessing step, because does not involve any model weights/params.
  852. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  853. easier torch.compile/torch.export (all ops are in torch).
  854. Args:
  855. coarse_features (`torch.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`):
  856. Coarse features
  857. coarse_scale (`float`): Scale between the image size and the coarse size
  858. Returns:
  859. keypoints (`torch.Tensor` of shape `(batch_size, 2, num_matches, 2)`):
  860. Keypoints coordinates.
  861. matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
  862. The confidence matching score of each keypoint.
  863. matched_indices (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
  864. Indices which indicates which keypoint in an image matched with which keypoint in the other image. For
  865. both image in the pair.
  866. """
  867. batch_size, _, embed_dim, height, width = coarse_features.shape
  868. # (batch_size, 2, embed_dim, height, width) -> (batch_size, 2, height * width, embed_dim)
  869. coarse_features = coarse_features.permute(0, 1, 3, 4, 2)
  870. coarse_features = coarse_features.reshape(batch_size, 2, -1, embed_dim)
  871. coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5
  872. coarse_features_0 = coarse_features[:, 0]
  873. coarse_features_1 = coarse_features[:, 1]
  874. similarity = coarse_features_0 @ coarse_features_1.transpose(-1, -2)
  875. similarity = similarity / self.config.coarse_matching_temperature
  876. if self.config.coarse_matching_skip_softmax:
  877. confidence = similarity
  878. else:
  879. confidence = nn.functional.softmax(similarity, 1) * nn.functional.softmax(similarity, 2)
  880. confidence = confidence.view(batch_size, height, width, height, width)
  881. matched_indices, matching_scores = self._get_matches_from_scores(confidence)
  882. keypoints = torch.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale
  883. return keypoints, matching_scores, matched_indices
  884. def _get_first_stage_fine_matching(
  885. self,
  886. fine_confidence: torch.Tensor,
  887. coarse_matched_keypoints: torch.Tensor,
  888. fine_window_size: int,
  889. fine_scale: float,
  890. ) -> tuple[torch.Tensor, torch.Tensor]:
  891. """
  892. For each coarse pixel, retrieve the highest fine confidence score and index.
  893. The index represents the matching between a pixel position in the fine window in the first image and a pixel
  894. position in the fine window of the second image.
  895. For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index 38
  896. (2474 // 64) in the fine window of the first image, and the index 42 in the second image. This means that 38
  897. which corresponds to the position (4, 6) (4 // 8 and 4 % 8) is matched with the position (5, 2). In this example
  898. the coarse matched coordinate will be shifted to the matched fine coordinates in the first and second image.
  899. Note:
  900. This step can be done as a postprocessing step, because does not involve any model weights/params.
  901. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  902. easier torch.compile/torch.export (all ops are in torch).
  903. Args:
  904. fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
  905. First stage confidence of matching fine features between the first and the second image
  906. coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
  907. Coarse matched keypoint between the first and the second image.
  908. fine_window_size (`int`):
  909. Size of the window used to refine matches
  910. fine_scale (`float`):
  911. Scale between the size of fine features and coarse features
  912. Returns:
  913. indices (`torch.Tensor` of shape `(2, num_matches, 1)`):
  914. Indices of the fine coordinate matched in the fine window
  915. fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
  916. Coordinates of matched keypoints after the first fine stage
  917. """
  918. batch_size, num_keypoints, _, _ = fine_confidence.shape
  919. fine_kernel_size = torch_int(fine_window_size**0.5)
  920. fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, -1)
  921. values, indices = torch.max(fine_confidence, dim=-1)
  922. indices = indices[..., None]
  923. indices_0 = indices // fine_window_size
  924. indices_1 = indices % fine_window_size
  925. grid = create_meshgrid(
  926. fine_kernel_size,
  927. fine_kernel_size,
  928. normalized_coordinates=False,
  929. device=fine_confidence.device,
  930. dtype=fine_confidence.dtype,
  931. )
  932. grid = grid - (fine_kernel_size // 2) + 0.5
  933. grid = grid.reshape(1, 1, -1, 2).expand(batch_size, num_keypoints, -1, -1)
  934. delta_0 = torch.gather(grid, 1, indices_0.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)
  935. delta_1 = torch.gather(grid, 1, indices_1.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)
  936. fine_matches_0 = coarse_matched_keypoints[:, 0] + delta_0 * fine_scale
  937. fine_matches_1 = coarse_matched_keypoints[:, 1] + delta_1 * fine_scale
  938. indices = torch.stack([indices_0, indices_1], dim=1)
  939. fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)
  940. return indices, fine_matches
  941. def _get_second_stage_fine_matching(
  942. self,
  943. indices: torch.Tensor,
  944. fine_matches: torch.Tensor,
  945. fine_confidence: torch.Tensor,
  946. fine_window_size: int,
  947. fine_scale: float,
  948. ) -> torch.Tensor:
  949. """
  950. For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this position.
  951. After applying softmax to these confidences, compute the 2D spatial expected coordinates.
  952. Shift the first stage fine matching with these expected coordinates.
  953. Note:
  954. This step can be done as a postprocessing step, because does not involve any model weights/params.
  955. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  956. easier torch.compile/torch.export (all ops are in torch).
  957. Args:
  958. indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
  959. Indices representing the position of each keypoint in the fine window
  960. fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
  961. Coordinates of matched keypoints after the first fine stage
  962. fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
  963. Second stage confidence of matching fine features between the first and the second image
  964. fine_window_size (`int`):
  965. Size of the window used to refine matches
  966. fine_scale (`float`):
  967. Scale between the size of fine features and coarse features
  968. Returns:
  969. fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
  970. Coordinates of matched keypoints after the second fine stage
  971. """
  972. batch_size, num_keypoints, _, _ = fine_confidence.shape
  973. fine_kernel_size = torch_int(fine_window_size**0.5)
  974. indices_0 = indices[:, 0]
  975. indices_1 = indices[:, 1]
  976. indices_1_i = indices_1 // fine_kernel_size
  977. indices_1_j = indices_1 % fine_kernel_size
  978. # matches_indices, indices_0, indices_1_i, indices_1_j of shape (num_matches, 3, 3)
  979. batch_indices = torch.arange(batch_size, device=indices_0.device).reshape(batch_size, 1, 1, 1)
  980. matches_indices = torch.arange(num_keypoints, device=indices_0.device).reshape(1, num_keypoints, 1, 1)
  981. indices_0 = indices_0[..., None]
  982. indices_1_i = indices_1_i[..., None]
  983. indices_1_j = indices_1_j[..., None]
  984. delta = create_meshgrid(3, 3, normalized_coordinates=True, device=indices_0.device).to(torch.long)
  985. delta = delta[None, ...]
  986. indices_1_i = indices_1_i + delta[..., 1]
  987. indices_1_j = indices_1_j + delta[..., 0]
  988. fine_confidence = fine_confidence.reshape(
  989. batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
  990. )
  991. # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3)
  992. fine_confidence = fine_confidence[batch_indices, matches_indices, indices_0, indices_1_i, indices_1_j]
  993. fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, 9)
  994. fine_confidence = nn.functional.softmax(
  995. fine_confidence / self.config.fine_matching_regress_temperature, dim=-1
  996. )
  997. heatmap = fine_confidence.reshape(batch_size, num_keypoints, 3, 3)
  998. fine_coordinates_normalized = spatial_expectation2d(heatmap, True)[0]
  999. fine_matches_0 = fine_matches[:, 0]
  1000. fine_matches_1 = fine_matches[:, 1] + (fine_coordinates_normalized * (3 // 2) * fine_scale)
  1001. fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)
  1002. return fine_matches
  1003. def _fine_matching(
  1004. self,
  1005. fine_features_0: torch.Tensor,
  1006. fine_features_1: torch.Tensor,
  1007. coarse_matched_keypoints: torch.Tensor,
  1008. fine_scale: float,
  1009. ) -> torch.Tensor:
  1010. """
  1011. For each coarse pixel with a corresponding window of fine features, compute the matching confidence between fine
  1012. features in the first image and the second image.
  1013. Fine features are sliced in two part :
  1014. - The first part used for the first stage are the first fine_hidden_size - config.fine_matching_slicedim (64 - 8
  1015. = 56 by default) features.
  1016. - The second part used for the second stage are the last config.fine_matching_slicedim (8 by default) features.
  1017. Each part is used to compute a fine confidence tensor of the following shape :
  1018. (batch_size, (coarse_height * coarse_width), fine_window_size, fine_window_size)
  1019. They correspond to the score between each fine pixel in the first image and each fine pixel in the second image.
  1020. Args:
  1021. fine_features_0 (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size ** 2)`):
  1022. Fine features from the first image
  1023. fine_features_1 (`torch.Tensor` of shape `(num_matches, (fine_kernel_size + 2) ** 2, (fine_kernel_size + 2)
  1024. ** 2)`):
  1025. Fine features from the second image
  1026. coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
  1027. Keypoint coordinates found in coarse matching for the first and second image
  1028. fine_scale (`int`):
  1029. Scale between the size of fine features and coarse features
  1030. Returns:
  1031. fine_coordinates (`torch.Tensor` of shape `(2, num_matches, 2)`):
  1032. Matched keypoint between the first and the second image. All matched keypoints are concatenated in the
  1033. second dimension.
  1034. """
  1035. batch_size, num_keypoints, fine_window_size, fine_embed_dim = fine_features_0.shape
  1036. fine_matching_slice_dim = self.config.fine_matching_slice_dim
  1037. fine_kernel_size = torch_int(fine_window_size**0.5)
  1038. # Split fine features into first and second stage features
  1039. split_fine_features_0 = torch.split(fine_features_0, fine_embed_dim - fine_matching_slice_dim, -1)
  1040. split_fine_features_1 = torch.split(fine_features_1, fine_embed_dim - fine_matching_slice_dim, -1)
  1041. # Retrieve first stage fine features
  1042. fine_features_0 = split_fine_features_0[0]
  1043. fine_features_1 = split_fine_features_1[0]
  1044. # Normalize first stage fine features
  1045. fine_features_0 = fine_features_0 / fine_features_0.shape[-1] ** 0.5
  1046. fine_features_1 = fine_features_1 / fine_features_1.shape[-1] ** 0.5
  1047. # Compute first stage confidence
  1048. fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)
  1049. fine_confidence = nn.functional.softmax(fine_confidence, 1) * nn.functional.softmax(fine_confidence, 2)
  1050. fine_confidence = fine_confidence.reshape(
  1051. batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
  1052. )
  1053. fine_confidence = fine_confidence[..., 1:-1, 1:-1]
  1054. first_stage_fine_confidence = fine_confidence.reshape(
  1055. batch_size, num_keypoints, fine_window_size, fine_window_size
  1056. )
  1057. fine_indices, fine_matches = self._get_first_stage_fine_matching(
  1058. first_stage_fine_confidence,
  1059. coarse_matched_keypoints,
  1060. fine_window_size,
  1061. fine_scale,
  1062. )
  1063. # Retrieve second stage fine features
  1064. fine_features_0 = split_fine_features_0[1]
  1065. fine_features_1 = split_fine_features_1[1]
  1066. # Normalize second stage fine features
  1067. fine_features_1 = fine_features_1 / fine_matching_slice_dim**0.5
  1068. # Compute second stage fine confidence
  1069. second_stage_fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)
  1070. fine_coordinates = self._get_second_stage_fine_matching(
  1071. fine_indices,
  1072. fine_matches,
  1073. second_stage_fine_confidence,
  1074. fine_window_size,
  1075. fine_scale,
  1076. )
  1077. return fine_coordinates
  1078. @auto_docstring
  1079. @can_return_tuple
  1080. def forward(
  1081. self,
  1082. pixel_values: torch.FloatTensor,
  1083. labels: torch.LongTensor | None = None,
  1084. **kwargs: Unpack[TransformersKwargs],
  1085. ) -> EfficientLoFTRKeypointMatchingOutput:
  1086. r"""
  1087. Examples:
  1088. ```python
  1089. >>> from transformers import AutoImageProcessor, AutoModel
  1090. >>> import torch
  1091. >>> from PIL import Image
  1092. >>> import httpx
  1093. >>> from io import BytesIO
  1094. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
  1095. >>> with httpx.stream("GET", url) as response:
  1096. ... image1 = Image.open(BytesIO(response.read()))
  1097. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
  1098. >>> with httpx.stream("GET", url) as response:
  1099. ... image2 = Image.open(BytesIO(response.read()))
  1100. >>> images = [image1, image2]
  1101. >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
  1102. >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")
  1103. >>> with torch.no_grad():
  1104. >>> inputs = processor(images, return_tensors="pt")
  1105. >>> outputs = model(**inputs)
  1106. ```"""
  1107. if labels is not None:
  1108. raise ValueError("SuperGlue is not trainable, no labels should be provided.")
  1109. # 1. Extract coarse and residual features
  1110. model_outputs: BackboneOutput = self.efficientloftr(pixel_values, **kwargs)
  1111. features = model_outputs.feature_maps
  1112. # 2. Compute coarse-level matching
  1113. coarse_features = features[0]
  1114. coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]
  1115. batch_size, _, channels, height, width = pixel_values.shape
  1116. coarse_scale = height / coarse_height
  1117. coarse_keypoints, coarse_matching_scores, coarse_matched_indices = self._coarse_matching(
  1118. coarse_features, coarse_scale
  1119. )
  1120. # 3. Fine-level refinement
  1121. residual_features = features[1:]
  1122. coarse_features = coarse_features / self.config.hidden_size**0.5
  1123. fine_features_0, fine_features_1 = self.refinement_layer(coarse_features, residual_features)
  1124. # Filter fine features with coarse matches indices
  1125. _, _, num_keypoints = coarse_matching_scores.shape
  1126. batch_indices = torch.arange(batch_size)[..., None]
  1127. fine_features_0 = fine_features_0[batch_indices, coarse_matched_indices[:, 0]]
  1128. fine_features_1 = fine_features_1[batch_indices, coarse_matched_indices[:, 1]]
  1129. # 4. Computer fine-level matching
  1130. fine_height = torch_int(coarse_height * coarse_scale)
  1131. fine_scale = height / fine_height
  1132. matching_keypoints = self._fine_matching(fine_features_0, fine_features_1, coarse_keypoints, fine_scale)
  1133. matching_keypoints[:, :, :, 0] = matching_keypoints[:, :, :, 0] / width
  1134. matching_keypoints[:, :, :, 1] = matching_keypoints[:, :, :, 1] / height
  1135. loss = None
  1136. return EfficientLoFTRKeypointMatchingOutput(
  1137. loss=loss,
  1138. matches=coarse_matched_indices,
  1139. matching_scores=coarse_matching_scores,
  1140. keypoints=matching_keypoints,
  1141. hidden_states=model_outputs.hidden_states,
  1142. attentions=model_outputs.attentions,
  1143. )
# Names exported by `from <module> import *`: the abstract base class and the two concrete models.
__all__ = ["EfficientLoFTRPreTrainedModel", "EfficientLoFTRModel", "EfficientLoFTRForKeypointMatching"]