| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Optional
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2CLS, ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BackboneOutput
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import (
- ModelOutput,
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- torch_int,
- )
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_efficientloftr import EfficientLoFTRConfig
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of EfficientLoFTR keypoint matching models. Due to the nature of keypoint detection and matching, the number
    of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
    images, the maximum number of matches is set as the dimension of the matches and matching scores.
    """
)
class EfficientLoFTRKeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`.
    attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
    """

    # Every field defaults to None so the output can be built with any subset of values.
    loss: torch.FloatTensor | None = None
    matches: torch.FloatTensor | None = None
    matching_scores: torch.FloatTensor | None = None
    keypoints: torch.FloatTensor | None = None
    # Variable-length tuples, matching the `tuple[torch.FloatTensor, ...]` types documented above.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@compile_compatible_method_lru_cache(maxsize=32)
def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
    """
    Build the (un-rotated) 2D positional angles for a feature grid.

    Args:
        inv_freq (`torch.Tensor`):
            Inverse frequencies; the result inherits its dtype and device.
        embed_height (`int`):
            Number of rows of the grid.
        embed_width (`int`):
            Number of columns of the grid.
        hidden_size (`int`):
            Model hidden size; the last dimension of the result is `hidden_size // 2`.

    Returns:
        `torch.Tensor` of shape `(1, embed_height, embed_width, hidden_size // 2)` whose even channels hold the
        1-based row index multiplied by `inv_freq` and whose odd channels hold the 1-based column index multiplied
        by `inv_freq`.
    """
    # Running sums over a grid of ones give 1-based row / column coordinates.
    unit_grid = inv_freq.new_ones((embed_height, embed_width))
    row_positions = unit_grid.cumsum(0).unsqueeze(-1)
    col_positions = unit_grid.cumsum(1).unsqueeze(-1)

    angles = inv_freq.new_zeros((1, embed_height, embed_width, hidden_size // 2))
    # Interleave the two axes: even channels encode rows, odd channels encode columns.
    angles[:, :, :, 0::2] = row_positions * inv_freq
    angles[:, :, :, 1::2] = col_positions * inv_freq
    return angles
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->EfficientLoFTR
class EfficientLoFTRRotaryEmbedding(nn.Module):
    """
    Produces the cos/sin tables of a 2D rotary position embedding, sized for the feature grid obtained after the
    query aggregation (kernel `config.q_aggregation_kernel_size`, stride `config.q_aggregation_stride`).
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    # Ignore copy
    def __init__(self, config: EfficientLoFTRConfig, device=None):
        super().__init__()
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]

        # The "default" RoPE variant is implemented locally; every other variant comes from the shared registry.
        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        # Non-persistent buffers: they follow the module across devices but are not stored in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    # Ignore copy
    def compute_default_rope_parameters(
        config: EfficientLoFTRConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_parameters = config.rope_parameters
        base = rope_parameters["rope_theta"]
        partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)

        # Fall back to the per-head width derived from the hidden size when `head_dim` is absent or falsy.
        head_dim = getattr(config, "head_dim", None)
        if not head_dim:
            head_dim = config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)

        # Compute the inverse frequencies
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base**exponents)

        # The scaling factor is unused in this type of RoPE
        return inv_freq, 1.0

    # Ignore copy
    @torch.no_grad()
    def forward(
        self, x: torch.Tensor, position_ids: torch.LongTensor | None = None, layer_type=None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Return the `(cos, sin)` tables for the aggregated grid of `x`.

        Only the last two dimensions of `x` (feature height and width), its device and its dtype are used;
        `position_ids` and `layer_type` are accepted but ignored.
        """
        kernel_size = self.config.q_aggregation_kernel_size
        stride = self.config.q_aggregation_stride

        # Spatial size of the grid after an un-padded aggregation with the configured kernel / stride.
        feats_height, feats_width = x.shape[-2:]
        embed_height = (feats_height - kernel_size) // stride + 1
        embed_width = (feats_width - kernel_size) // stride + 1

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
            sin = emb.sin()
            cos = emb.cos()

        # Duplicate every angle so consecutive channel pairs share it, then match the input's device / dtype.
        cos = cos.repeat_interleave(2, dim=-1).to(device=x.device, dtype=x.dtype)
        sin = sin.repeat_interleave(2, dim=-1).to(device=x.device, dtype=x.dtype)
        return cos, sin
# Copied from transformers.models.rt_detr_v2.modeling_rt_detr_v2.RTDetrV2ConvNormLayer with RTDetrV2->EfficientLoFTR
class EfficientLoFTRConvNormLayer(nn.Module):
    """
    Bias-free 2D convolution followed by batch normalization and an optional activation.

    When `padding` is `None` it defaults to `(kernel_size - 1) // 2`. When `activation` is `None` the activation
    is an identity; otherwise it is instantiated from `ACT2CLS[activation]`.
    """

    def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=False)
        self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
        self.activation = ACT2CLS[activation]() if activation is not None else nn.Identity()

    def forward(self, hidden_state):
        # conv -> norm -> activation
        return self.activation(self.norm(self.conv(hidden_state)))
class EfficientLoFTRRepVGGBlock(GradientCheckpointingLayer):
    """
    RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".

    Sums a 3x3 conv-norm branch, a 1x1 conv-norm branch and, when the input and output channel counts match and
    the stride is 1, a batch-norm identity branch, then applies the configured activation.
    """

    def __init__(self, config: EfficientLoFTRConfig, stage_idx: int, block_idx: int):
        super().__init__()
        in_channels = config.stage_block_in_channels[stage_idx][block_idx]
        out_channels = config.stage_block_out_channels[stage_idx][block_idx]
        stride = config.stage_block_stride[stage_idx][block_idx]
        activation = config.activation_function

        self.conv1 = EfficientLoFTRConvNormLayer(
            config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.conv2 = EfficientLoFTRConvNormLayer(
            config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0
        )

        # The identity branch only exists when it can be added without changing shape.
        keeps_shape = in_channels == out_channels and stride == 1
        self.identity = nn.BatchNorm2d(in_channels) if keeps_shape else None
        self.activation = ACT2FN[activation] if activation is not None else nn.Identity()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # A missing identity branch contributes a plain 0 to the sum.
        shortcut = 0 if self.identity is None else self.identity(hidden_states)
        merged = self.conv1(hidden_states) + self.conv2(hidden_states) + shortcut
        return self.activation(merged)
class EfficientLoFTRRepVGGStage(nn.Module):
    """Sequence of `config.stage_num_blocks[stage_idx]` RepVGG blocks applied one after another."""

    def __init__(self, config: EfficientLoFTRConfig, stage_idx: int):
        super().__init__()
        num_blocks = config.stage_num_blocks[stage_idx]
        self.blocks = nn.ModuleList(
            [EfficientLoFTRRepVGGBlock(config, stage_idx, block_idx) for block_idx in range(num_blocks)]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Each block consumes the output of the previous one.
        for repvgg_block in self.blocks:
            hidden_states = repvgg_block(hidden_states)
        return hidden_states
class EfficientLoFTRepVGG(nn.Module):
    """Backbone made of one RepVGG stage per entry of `config.stage_stride`."""

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        num_stages = len(config.stage_stride)
        self.stages = nn.ModuleList([EfficientLoFTRRepVGGStage(config, stage_idx) for stage_idx in range(num_stages)])

    def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
        """Run all stages in order and return the outputs of every stage except the first."""
        kept_outputs = []
        for stage_idx, stage in enumerate(self.stages):
            hidden_states = stage(hidden_states)
            # Exclude first stage in outputs
            if stage_idx > 0:
                kept_outputs.append(hidden_states)
        return kept_outputs
class EfficientLoFTRAggregationLayer(nn.Module):
    """
    Spatially reduces the query and key/value feature maps before attention.

    Queries go through an un-padded, bias-free depthwise convolution (`groups=hidden_size`); keys/values go through
    a max pooling. Both results are moved to channels-last layout and normalized with the same `LayerNorm`.
    """

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.q_aggregation = nn.Conv2d(
            hidden_size,
            hidden_size,
            kernel_size=config.q_aggregation_kernel_size,
            padding=0,
            stride=config.q_aggregation_stride,
            bias=False,
            groups=hidden_size,
        )
        self.kv_aggregation = torch.nn.MaxPool2d(
            kernel_size=config.kv_aggregation_kernel_size, stride=config.kv_aggregation_stride
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Return `(aggregated_queries, aggregated_keys_values)`, both permuted with `(0, 2, 3, 1)` and normalized.

        Keys/values are taken from `encoder_hidden_states` when it is given (cross-attention), otherwise from
        `hidden_states` (self-attention).
        """
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        # Channel dimension goes last so that LayerNorm normalizes over it.
        aggregated_queries = self.q_aggregation(hidden_states).permute(0, 2, 3, 1)
        aggregated_kv = self.kv_aggregation(kv_source).permute(0, 2, 3, 1)

        return self.norm(aggregated_queries), self.norm(aggregated_kv)
# Copied from transformers.models.cohere.modeling_cohere.rotate_half
def rotate_half(x):
    """Map every consecutive pair `(a, b)` of the last dimension to `(-b, a)`."""
    # Split and rotate. Note that this function is different from e.g. Llama.
    even_part, odd_part = x[..., 0::2], x[..., 1::2]
    # Stacking on a new trailing axis and flattening it re-interleaves the two halves.
    return torch.stack((-odd_part, even_part), dim=-1).flatten(start_dim=-2)
# Copied from transformers.models.cohere.modeling_cohere.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    The rotation is computed in float32 and both results are cast back to the original dtype of `q`.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, if `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `q` and `k` have the
            shape [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; if `q` and `k` have the shape
            [batch_size, seq_len, heads, head_dim], use `unsqueeze_dim=2`.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    output_dtype = q.dtype
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Upcast first so the products are evaluated in float32.
        states = states.float()
        return ((states * cos) + (rotate_half(states) * sin)).to(dtype=output_dtype)

    q_embed = _rotate(q)
    k_embed = _rotate(k)
    return q_embed, k_embed
# Copied from transformers.models.cohere.modeling_cohere.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat: hand back the very same tensor.
        return hidden_states

    # Insert a repeat axis right after the head axis, broadcast it, then merge it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Copied from transformers.models.llama.modeling_llama.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain matmul/softmax attention.

    Keys and values are first repeated `module.num_key_value_groups` times along the head axis. Returns the
    attention output with the head and sequence axes swapped (and made contiguous) together with the
    post-dropout attention weights.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Additive mask.
        scores = scores + attention_mask

    # Softmax is evaluated in float32, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
    return attn_output, attn_weights
class EfficientLoFTRAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper

    Supports self-attention (keys/values come from `hidden_states`) and cross-attention (keys/values come from
    `encoder_hidden_states`). Rotary position embeddings, when given, are applied to queries and keys over the
    full hidden dimension before the tensors are split into heads.
    """

    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, dim)`):
                Query source.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Key/value source for cross-attention; `hidden_states` is used when it is `None`. Its sequence
                length may differ from the query sequence length.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
                `(cos, sin)` tables applied to both queries and keys. They must broadcast against both, so
                queries and keys need the same sequence length whenever this is given.

        Returns:
            The projected attention output with the leading `(batch_size, seq_len)` dimensions of
            `hidden_states`, and the attention weights as returned by the selected attention implementation.
        """
        batch_size, seq_len, dim = hidden_states.shape
        input_shape = hidden_states.shape[:-1]

        current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        # Keys/values must be reshaped with their OWN token count. Reusing the query `seq_len` either raises or
        # silently folds extra tokens into the head axis when the two lengths differ.
        kv_seq_len = current_states.shape[1]

        # Queries/keys stay at full width `dim` for now so the rotary tables can be applied before the head split.
        query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim)
        key_states = self.k_proj(current_states).view(batch_size, kv_seq_len, -1, dim)
        value_states = self.v_proj(current_states).view(batch_size, kv_seq_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)

        # Split into heads: (batch, tokens, heads, head_dim) -> (batch, heads, tokens, head_dim)
        query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, kv_seq_len, -1, self.head_dim).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class EfficientLoFTRMLP(nn.Module):
    """
    Two bias-free linear layers with an activation in between, followed by a `LayerNorm`.

    The input width is `2 * config.hidden_size` and the output width is `config.hidden_size`.
    """

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size * 2, config.intermediate_size, bias=False)
        self.activation = ACT2FN[config.mlp_activation_function]
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.fc2(self.activation(self.fc1(hidden_states)))
        return self.layer_norm(projected)
class EfficientLoFTRAggregatedAttention(nn.Module):
    """
    Attention evaluated on spatially aggregated features, whose result is upsampled back, fused with the input
    through an MLP and added to the input as a residual.
    """

    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()
        self.q_aggregation_kernel_size = config.q_aggregation_kernel_size
        self.aggregation = EfficientLoFTRAggregationLayer(config)
        self.attention = EfficientLoFTRAttention(config, layer_idx)
        self.mlp = EfficientLoFTRMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): 4D feature map with dimensions `(batch_size, embed_dim, ., .)`.
            encoder_hidden_states (`torch.Tensor`, *optional*): Key/value feature map for cross-attention.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`, *optional*): Forwarded to the attention.

        Returns:
            `hidden_states` plus the MLP output, in the same layout as the input.
        """
        batch_size, embed_dim, _, _ = hidden_states.shape

        # Aggregate features
        queries, keys_values = self.aggregation(hidden_states, encoder_hidden_states)
        _, aggregated_h, aggregated_w, _ = queries.shape

        # Multi-head attention over flattened spatial positions
        attn_output, _ = self.attention(
            queries.reshape(batch_size, -1, embed_dim),
            keys_values.reshape(batch_size, -1, embed_dim),
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # Upsample features
        # (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim, h, w) with seq_len = h * w
        attn_output = attn_output.permute(0, 2, 1).reshape(batch_size, embed_dim, aggregated_h, aggregated_w)
        attn_output = torch.nn.functional.interpolate(
            attn_output, scale_factor=self.q_aggregation_kernel_size, mode="bilinear", align_corners=False
        )

        # Concatenate along channels, run the MLP in channels-last layout, then restore channels-first.
        fused = torch.cat([hidden_states, attn_output], dim=1).permute(0, 2, 3, 1)
        residual_update = self.mlp(fused).permute(0, 3, 1, 2)

        return hidden_states + residual_update
class EfficientLoFTRLocalFeatureTransformerLayer(GradientCheckpointingLayer):
    """Self-attention over each image of a pair, followed by cross-attention between the two images."""

    def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
        super().__init__()
        self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)
        self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): 5D tensor `(batch_size, 2, embed_dim, height, width)`.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Only passed to the self-attention; the cross-attention runs without it.

        Returns:
            `torch.Tensor` with the image-pair axis restored at dimension 1.
        """
        _, _, embed_dim, height, width = hidden_states.shape

        # Self-attention treats both images of every pair as independent batch entries.
        flat_states = hidden_states.reshape(-1, embed_dim, height, width)
        flat_states = self.self_attention(flat_states, position_embeddings=position_embeddings, **kwargs)

        ###
        # Implementation of a bug in the original implementation regarding the cross-attention
        # See : https://github.com/zju3dv/MatchAnything/issues/26
        # The second call deliberately receives the ALREADY UPDATED `features_0`; do not "fix" the ordering.
        features_0, features_1 = flat_states.reshape(-1, 2, embed_dim, height, width).unbind(dim=1)
        features_0 = self.cross_attention(features_0, features_1, **kwargs)
        features_1 = self.cross_attention(features_1, features_0, **kwargs)
        ###

        return torch.stack((features_0, features_1), dim=1)
class EfficientLoFTRLocalFeatureTransformer(nn.Module):
    """Stack of `config.num_attention_layers` local feature transformer layers applied sequentially."""

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_attention_layers):
            self.layers.append(EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=layer_idx))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # The same position embeddings and extra kwargs are handed to every layer.
        for transformer_layer in self.layers:
            hidden_states = transformer_layer(hidden_states, position_embeddings=position_embeddings, **kwargs)
        return hidden_states
class EfficientLoFTROutConvBlock(nn.Module):
    """
    Fuses a residual feature map into the running pyramid features and doubles the spatial resolution.

    The residual is projected with a 1x1 convolution to `intermediate_size` channels, added to `hidden_states`,
    refined by conv3x3 -> batch norm -> activation -> conv3x3 (back to `hidden_size` channels) and bilinearly
    upsampled by a factor of 2.
    """

    def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.out_conv1 = nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.out_conv2 = nn.Conv2d(
            intermediate_size, intermediate_size, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.batch_norm = nn.BatchNorm2d(intermediate_size)
        self.activation = ACT2CLS[config.mlp_activation_function]()
        self.out_conv3 = nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, hidden_states: torch.Tensor, residual_states: torch.Tensor) -> torch.Tensor:
        # Project the residual and merge it with the incoming features.
        fused = self.out_conv1(residual_states) + hidden_states
        # Refinement stack.
        fused = self.out_conv3(self.activation(self.batch_norm(self.out_conv2(fused))))
        return nn.functional.interpolate(fused, scale_factor=2.0, mode="bilinear", align_corners=False)
class EfficientLoFTRFineFusionLayer(nn.Module):
    """Turns coarse features and backbone residuals into per-coarse-pixel windows of fine features."""

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__()
        self.fine_kernel_size = config.fine_kernel_size
        fine_fusion_dims = config.fine_fusion_dims

        first_dim = fine_fusion_dims[0]
        self.out_conv = nn.Conv2d(first_dim, first_dim, kernel_size=1, stride=1, padding=0, bias=False)

        # One fusion block per pair of consecutive dims: block i maps dims[i] (residual) onto dims[i - 1].
        self.out_conv_layers = nn.ModuleList(
            [
                EfficientLoFTROutConvBlock(config, fine_fusion_dims[i], fine_fusion_dims[i - 1])
                for i in range(1, len(fine_fusion_dims))
            ]
        )

    def forward_pyramid(
        self,
        hidden_states: torch.Tensor,
        residual_states: list[torch.Tensor],
    ) -> torch.Tensor:
        """Project and upsample `hidden_states`, then fuse `residual_states[i]` with the i-th fusion block."""
        hidden_states = nn.functional.interpolate(
            self.out_conv(hidden_states), scale_factor=2.0, mode="bilinear", align_corners=False
        )
        for level, fusion_block in enumerate(self.out_conv_layers):
            hidden_states = fusion_block(hidden_states, residual_states[level])
        return hidden_states

    def forward(
        self,
        coarse_features: torch.Tensor,
        residual_features: list[torch.Tensor] | tuple[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each image pair, compute the fine features of pixels.
        In both images, compute a patch of fine features center cropped around each coarse pixel.
        In the first image, the feature patch is kernel_size large and long.
        In the second image, it is (kernel_size + 2) large and long.
        """
        batch_size, _, embed_dim, coarse_height, coarse_width = coarse_features.shape
        flat_coarse = coarse_features.reshape(-1, embed_dim, coarse_height, coarse_width)
        # The pyramid consumes the residuals in the reverse of the order they were given.
        pyramid_residuals = list(residual_features)[::-1]

        # 1. Fine feature extraction
        fine_features = self.forward_pyramid(flat_coarse, pyramid_residuals)
        _, fine_embed_dim, fine_height, fine_width = fine_features.shape
        fine_features = fine_features.reshape(batch_size, 2, fine_embed_dim, fine_height, fine_width)

        # 2. Unfold all local windows in crops
        # One window per coarse pixel: step by the fine / coarse resolution ratio.
        stride = int(fine_height // coarse_height)
        small_window = self.fine_kernel_size
        large_window = small_window + 2

        fine_features_0 = nn.functional.unfold(
            fine_features[:, 0], kernel_size=small_window, stride=stride, padding=0
        )
        _, _, seq_len = fine_features_0.shape
        fine_features_0 = fine_features_0.reshape(batch_size, -1, small_window**2, seq_len).permute(0, 3, 2, 1)

        # The larger window is padded by 1; its output is reshaped with the `seq_len` of the first image.
        fine_features_1 = nn.functional.unfold(
            fine_features[:, 1], kernel_size=large_window, stride=stride, padding=1
        )
        fine_features_1 = fine_features_1.reshape(batch_size, -1, large_window**2, seq_len).permute(0, 3, 2, 1)

        return fine_features_0, fine_features_1
@auto_docstring
class EfficientLoFTRPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientLoFTRConfig
    base_model_prefix = "efficientloftr"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_record_outputs = {
        "hidden_states": EfficientLoFTRRepVGGBlock,
        "attentions": EfficientLoFTRAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d)):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
            # Only modules that track running statistics expose a non-None `running_mean`.
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)
            return

        if isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return

        if isinstance(module, EfficientLoFTRRotaryEmbedding):
            # Recompute the inverse frequencies with the same function the module used at construction time.
            if module.rope_type == "default":
                rope_fn = module.compute_default_rope_parameters
            else:
                rope_fn = ROPE_INIT_FUNCTIONS[module.rope_type]
            buffer_value, _ = rope_fn(module.config)
            init.copy_(module.inv_freq, buffer_value)
            init.copy_(module.original_inv_freq, buffer_value)

    # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
    def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """
        Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same,
        extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for EfficientLoFTR. This is
        a workaround for the issue discussed in :
        https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446

        Args:
            pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)

        Returns:
            pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
        """
        # Select channel 0, then re-insert a singleton channel axis.
        first_channel = pixel_values[:, 0, :, :]
        return first_channel.unsqueeze(1)
@auto_docstring(
    custom_intro="""
    EfficientLoFTR model taking images as inputs and outputting the features of the images.
    """
)
class EfficientLoFTRModel(EfficientLoFTRPreTrainedModel):
    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__(config)

        self.config = config
        self.backbone = EfficientLoFTRepVGG(config)
        self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config)
        self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BackboneOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModel
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
        >>> with httpx.stream("GET", url) as response:
        ...     image1 = Image.open(BytesIO(response.read()))

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
        >>> with httpx.stream("GET", url) as response:
        ...     image2 = Image.open(BytesIO(response.read()))

        >>> images = [image1, image2]

        >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
        >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")

        >>> with torch.no_grad():
        ...     inputs = processor(images, return_tensors="pt")
        ...     outputs = model(**inputs)
        ```"""
        if labels is not None:
            raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.")

        is_batch_of_pairs = pixel_values.ndim == 5 and pixel_values.size(1) == 2
        if not is_batch_of_pairs:
            raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")

        # Fold the pair axis into the batch axis so both images of every pair go through the backbone together.
        batch_size, _, channels, height, width = pixel_values.shape
        num_images = batch_size * 2
        flat_pixel_values = pixel_values.reshape(num_images, channels, height, width)
        flat_pixel_values = self.extract_one_channel_pixel_values(flat_pixel_values)

        # 1. Local Feature CNN
        backbone_features = self.backbone(flat_pixel_values)
        # The last stage output is the coarse feature map; the earlier ones are the residual features used in
        # EfficientLoFTRFineFusionLayer.
        residual_features = backbone_features[:-1]
        coarse_features = backbone_features[-1]
        coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]

        # 2. Coarse-level LoFTR module
        cos, sin = self.rotary_emb(coarse_features)
        position_embeddings = tuple(
            embedding.expand(num_images, -1, -1, -1).reshape(num_images, -1, coarse_embed_dim)
            for embedding in (cos, sin)
        )

        # Restore the explicit pair axis before running the transformer.
        paired_coarse_features = coarse_features.reshape(
            batch_size, 2, coarse_embed_dim, coarse_height, coarse_width
        )
        transformed_coarse_features = self.local_feature_transformer(
            paired_coarse_features, position_embeddings=position_embeddings, **kwargs
        )

        return BackboneOutput(feature_maps=(transformed_coarse_features, *residual_features))
- def mask_border(tensor: torch.Tensor, border_margin: int, value: bool | float | int) -> torch.Tensor:
- """
- Mask a tensor border with a given value
- Args:
- tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
- The tensor to mask
- border_margin (`int`) :
- The size of the border
- value (`Union[bool, int, float]`):
- The value to place in the tensor's borders
- Returns:
- tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
- The masked tensor
- """
- if border_margin <= 0:
- return tensor
- tensor[:, :border_margin] = value
- tensor[:, :, :border_margin] = value
- tensor[:, :, :, :border_margin] = value
- tensor[:, :, :, :, :border_margin] = value
- tensor[:, -border_margin:] = value
- tensor[:, :, -border_margin:] = value
- tensor[:, :, :, -border_margin:] = value
- tensor[:, :, :, :, -border_margin:] = value
- return tensor
- def create_meshgrid(
- height: int | torch.Tensor,
- width: int | torch.Tensor,
- normalized_coordinates: bool = False,
- device: torch.device | None = None,
- dtype: torch.dtype | None = None,
- ) -> torch.Tensor:
- """
- Copied from kornia library : kornia/kornia/utils/grid.py:26
- Generate a coordinate grid for an image.
- When the flag ``normalized_coordinates`` is set to True, the grid is
- normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
- function :py:func:`torch.nn.functional.grid_sample`.
- Args:
- height (`int`):
- The image height (rows).
- width (`int`):
- The image width (cols).
- normalized_coordinates (`bool`):
- Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the
- PyTorch function :py:func:`torch.nn.functional.grid_sample`.
- device (`torch.device`):
- The device on which the grid will be generated.
- dtype (`torch.dtype`):
- The data type of the generated grid.
- Return:
- grid (`torch.Tensor` of shape `(1, height, width, 2)`):
- The grid tensor.
- Example:
- >>> create_meshgrid(2, 2)
- tensor([[[[-1., -1.],
- [ 1., -1.]],
- <BLANKLINE>
- [[-1., 1.],
- [ 1., 1.]]]])
- >>> create_meshgrid(2, 2, normalized_coordinates=False)
- tensor([[[[0., 0.],
- [1., 0.]],
- <BLANKLINE>
- [[0., 1.],
- [1., 1.]]]])
- """
- xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
- ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
- if normalized_coordinates:
- xs = (xs / (width - 1) - 0.5) * 2
- ys = (ys / (height - 1) - 0.5) * 2
- grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)
- grid = grid.permute(1, 0, 2).unsqueeze(0)
- return grid
def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor:
    r"""
    Copied from kornia library : kornia/geometry/subpix/dsnt.py:76

    Compute the expectation of coordinate values using spatial probabilities.

    The input heatmap is assumed to represent a valid spatial probability distribution,
    which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.

    Args:
        input (`torch.Tensor` of shape `(batch_size, embed_dim, height, width)`):
            The input tensor representing dense spatial probabilities. It does not need to be contiguous.
        normalized_coordinates (`bool`):
            Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return
            the coordinates in the range of the input shape.

    Returns:
        output (`torch.Tensor` of shape `(batch_size, embed_dim, 2)`)
            Expected value of the 2D coordinates. Output order of the coordinates is (x, y).

    Examples:
        >>> heatmaps = torch.tensor([[[
        ... [0., 0., 0.],
        ... [0., 0., 0.],
        ... [0., 1., 0.]]]])
        >>> spatial_expectation2d(heatmaps, False)
        tensor([[[1., 2.]]])
    """
    batch_size, embed_dim, height, width = input.shape

    # Create coordinates grid.
    grid = create_meshgrid(height, width, normalized_coordinates, input.device)
    grid = grid.to(input.dtype)

    pos_x = grid[..., 0].reshape(-1)
    pos_y = grid[..., 1].reshape(-1)

    # `reshape` (unlike `view`) also accepts non-contiguous heatmaps, e.g. ones obtained by permuting axes.
    input_flat = input.reshape(batch_size, embed_dim, -1)

    # Compute the expectation of the coordinates.
    expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
    expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)

    output = torch.cat([expected_x, expected_y], -1)

    return output.reshape(batch_size, embed_dim, 2)
@auto_docstring(
    custom_intro="""
    EfficientLoFTR model taking images as inputs and outputting the matching of them.
    """
)
class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel):
    """EfficientLoFTR dense image matcher

    Given two images, we determine the correspondences by:
      1. Extracting coarse and fine features through a backbone
      2. Transforming coarse features through self and cross attention
      3. Matching coarse features to obtain coarse coordinates of matches
      4. Obtaining full resolution fine features by fusing transformed and backbone coarse features
      5. Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage refinement

    Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.
    Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed
    In CVPR, 2024. https://huggingface.co/papers/2403.04765
    """

    def __init__(self, config: EfficientLoFTRConfig):
        super().__init__(config)

        self.config = config
        self.efficientloftr = EfficientLoFTRModel(config)
        self.refinement_layer = EfficientLoFTRFineFusionLayer(config)

        self.post_init()

    def _get_matches_from_scores(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Based on a keypoint score matrix, compute the best keypoint matches between the first and second image.

        The scores are thresholded, their borders are masked out and only mutual nearest neighbors are kept. The
        arg-max of the remaining scores is then taken along each of the two image axes.

        Note:
            This step can be done as a postprocessing step, because does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            scores (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
                Scores of keypoints

        Returns:
            matching_indices (`torch.Tensor` of shape `(batch_size, 2, height * width)`):
                Arg-max indices of the masked scores. Entry `[:, 0]` is the arg-max over the first flattened image
                axis, entry `[:, 1]` the arg-max over the second one. Positions whose best score is not strictly
                positive are set to -1.
            matching_scores (`torch.Tensor` of shape `(batch_size, 2, height * width)`):
                Maximum masked score associated with each entry of `matching_indices`
        """
        batch_size, height0, width0, height1, width1 = scores.shape

        scores = scores.view(batch_size, height0 * width0, height1 * width1)

        # For each keypoint, get the best match
        max_0 = scores.max(2, keepdim=True).values
        max_1 = scores.max(1, keepdim=True).values

        # 1. Thresholding
        mask = scores > self.config.coarse_matching_threshold

        # 2. Border removal
        mask = mask.reshape(batch_size, height0, width0, height1, width1)
        mask = mask_border(mask, self.config.coarse_matching_border_removal, False)
        mask = mask.reshape(batch_size, height0 * width0, height1 * width1)

        # 3. Mutual nearest neighbors
        mask = mask * (scores == max_0) * (scores == max_1)

        # 4. Fine coarse matches
        masked_scores = scores * mask
        matching_scores_0, max_indices_0 = masked_scores.max(1)
        matching_scores_1, max_indices_1 = masked_scores.max(2)

        # Stack along a new axis 1 so that every batch element keeps its own pair of index rows.
        # (Concatenating along the batch axis and reshaping mixes batch elements as soon as batch_size > 1.)
        matching_indices = torch.stack([max_indices_0, max_indices_1], dim=1)
        matching_scores = torch.stack([matching_scores_0, matching_scores_1], dim=1)

        # For the keypoints not meeting the threshold score, set the indices to -1 which corresponds to no matches found
        matching_indices = torch.where(matching_scores > 0, matching_indices, -1)

        return matching_indices, matching_scores

    def _coarse_matching(
        self, coarse_features: torch.Tensor, coarse_scale: float
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        For each image pair, compute the matching confidence between each coarse element (by default (image_height / 8)
        * (image_width / 8 elements)) from the first image to the second image.

        Note:
            This step can be done as a postprocessing step, because does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            coarse_features (`torch.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`):
                Coarse features
            coarse_scale (`float`): Scale between the image size and the coarse size

        Returns:
            keypoints (`torch.Tensor` of shape `(batch_size, 2, num_matches, 2)`):
                Keypoints coordinates.
            matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
                The confidence matching score of each keypoint.
            matched_indices (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
                Indices which indicates which keypoint in an image matched with which keypoint in the other image. For
                both image in the pair.
        """
        batch_size, _, embed_dim, height, width = coarse_features.shape

        # (batch_size, 2, embed_dim, height, width) -> (batch_size, 2, height * width, embed_dim)
        coarse_features = coarse_features.permute(0, 1, 3, 4, 2)
        coarse_features = coarse_features.reshape(batch_size, 2, -1, embed_dim)

        coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5
        coarse_features_0 = coarse_features[:, 0]
        coarse_features_1 = coarse_features[:, 1]

        similarity = coarse_features_0 @ coarse_features_1.transpose(-1, -2)
        similarity = similarity / self.config.coarse_matching_temperature

        if self.config.coarse_matching_skip_softmax:
            confidence = similarity
        else:
            # Dual-softmax: normalize over the positions of each image and combine both.
            confidence = nn.functional.softmax(similarity, 1) * nn.functional.softmax(similarity, 2)

        confidence = confidence.view(batch_size, height, width, height, width)
        matched_indices, matching_scores = self._get_matches_from_scores(confidence)

        # Flat index -> (x, y) position on the coarse grid, then rescaled by `coarse_scale`.
        keypoints = torch.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale

        return keypoints, matching_scores, matched_indices

    def _get_first_stage_fine_matching(
        self,
        fine_confidence: torch.Tensor,
        coarse_matched_keypoints: torch.Tensor,
        fine_window_size: int,
        fine_scale: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        For each coarse pixel, retrieve the highest fine confidence score and index.
        The index represents the matching between a pixel position in the fine window in the first image and a pixel
        position in the fine window of the second image.
        For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index 38
        (2474 // 64) in the fine window of the first image, and the index 42 in the second image. This means that 38
        which corresponds to the position (4, 6) (38 // 8 and 38 % 8) is matched with the position (5, 2). In this
        example the coarse matched coordinate will be shifted to the matched fine coordinates in the first and second
        image.

        Note:
            This step can be done as a postprocessing step, because does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            fine_confidence (`torch.Tensor` of shape `(batch_size, num_keypoints, fine_window_size, fine_window_size)`):
                First stage confidence of matching fine features between the first and the second image
            coarse_matched_keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                Coarse matched keypoint between the first and the second image.
            fine_window_size (`int`):
                Size of the window used to refine matches
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 1)`):
                Indices of the fine coordinate matched in the fine window
            fine_matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                Coordinates of matched keypoints after the first fine stage
        """
        batch_size, num_keypoints, _, _ = fine_confidence.shape
        fine_kernel_size = torch_int(fine_window_size**0.5)

        # Flatten the (window_0, window_1) score matrix so a single arg-max picks the best pair of fine positions.
        fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, -1)
        _, indices = torch.max(fine_confidence, dim=-1)
        indices = indices[..., None]
        indices_0 = indices // fine_window_size
        indices_1 = indices % fine_window_size

        grid = create_meshgrid(
            fine_kernel_size,
            fine_kernel_size,
            normalized_coordinates=False,
            device=fine_confidence.device,
            dtype=fine_confidence.dtype,
        )
        # Offsets of every fine position relative to the center of the window.
        grid = grid - (fine_kernel_size // 2) + 0.5
        grid = grid.reshape(1, 1, -1, 2).expand(batch_size, num_keypoints, -1, -1)
        # Look the offsets up along the window axis (axis 2). Axis 1 is the keypoint axis: gathering along it
        # always returns the offset of the first window position, whatever the matched index is.
        delta_0 = torch.gather(grid, 2, indices_0.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)
        delta_1 = torch.gather(grid, 2, indices_1.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)

        fine_matches_0 = coarse_matched_keypoints[:, 0] + delta_0 * fine_scale
        fine_matches_1 = coarse_matched_keypoints[:, 1] + delta_1 * fine_scale

        indices = torch.stack([indices_0, indices_1], dim=1)
        fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)

        return indices, fine_matches

    def _get_second_stage_fine_matching(
        self,
        indices: torch.Tensor,
        fine_matches: torch.Tensor,
        fine_confidence: torch.Tensor,
        fine_window_size: int,
        fine_scale: float,
    ) -> torch.Tensor:
        """
        For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this position.
        After applying softmax to these confidences, compute the 2D spatial expected coordinates.
        Shift the first stage fine matching with these expected coordinates.

        Note:
            This step can be done as a postprocessing step, because does not involve any model weights/params.
            However, we keep it in the modeling code for consistency with other keypoint matching models AND for
            easier torch.compile/torch.export (all ops are in torch).

        Args:
            indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 1)`):
                Indices representing the position of each keypoint in the fine window
            fine_matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                Coordinates of matched keypoints after the first fine stage
            fine_confidence (`torch.Tensor` of shape `(batch_size, num_keypoints, fine_window_size, (fine_kernel_size + 2) ** 2)`):
                Second stage confidence of matching fine features between the first and the second image
            fine_window_size (`int`):
                Size of the window used to refine matches
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            fine_matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                Coordinates of matched keypoints after the second fine stage
        """
        batch_size, num_keypoints, _, _ = fine_confidence.shape
        fine_kernel_size = torch_int(fine_window_size**0.5)

        indices_0 = indices[:, 0]
        indices_1 = indices[:, 1]
        indices_1_i = indices_1 // fine_kernel_size
        indices_1_j = indices_1 % fine_kernel_size

        # All index tensors below broadcast to (batch_size, num_keypoints, 3, 3)
        batch_indices = torch.arange(batch_size, device=indices_0.device).reshape(batch_size, 1, 1, 1)
        matches_indices = torch.arange(num_keypoints, device=indices_0.device).reshape(1, num_keypoints, 1, 1)
        indices_0 = indices_0[..., None]
        indices_1_i = indices_1_i[..., None]
        indices_1_j = indices_1_j[..., None]

        # 3x3 neighborhood offsets in {-1, 0, 1}; last axis is (x, y).
        # NOTE(review): the offsets are added to indices expressed in the un-padded window, while the tensor below
        # is indexed in the padded (fine_kernel_size + 2) window, so a row/column index of 0 yields -1 and wraps
        # around. Confirm against the reference implementation whether a +1 shift is expected here.
        delta = create_meshgrid(3, 3, normalized_coordinates=True, device=indices_0.device).to(torch.long)
        delta = delta[None, ...]

        indices_1_i = indices_1_i + delta[..., 1]
        indices_1_j = indices_1_j + delta[..., 0]

        fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
        )
        # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3)
        fine_confidence = fine_confidence[batch_indices, matches_indices, indices_0, indices_1_i, indices_1_j]
        fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, 9)
        fine_confidence = nn.functional.softmax(
            fine_confidence / self.config.fine_matching_regress_temperature, dim=-1
        )

        heatmap = fine_confidence.reshape(batch_size, num_keypoints, 3, 3)
        # Keep the batch axis: each image pair is refined with its own expected offsets, of shape
        # (batch_size, num_keypoints, 2). Taking only the first batch element would apply its offsets to all pairs.
        fine_coordinates_normalized = spatial_expectation2d(heatmap, True)

        # Only the keypoints of the second image are refined in this stage.
        fine_matches_0 = fine_matches[:, 0]
        fine_matches_1 = fine_matches[:, 1] + (fine_coordinates_normalized * (3 // 2) * fine_scale)

        fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)

        return fine_matches

    def _fine_matching(
        self,
        fine_features_0: torch.Tensor,
        fine_features_1: torch.Tensor,
        coarse_matched_keypoints: torch.Tensor,
        fine_scale: float,
    ) -> torch.Tensor:
        """
        For each coarse pixel with a corresponding window of fine features, compute the matching confidence between fine
        features in the first image and the second image.

        Fine features are sliced in two part :
        - The first part used for the first stage are the first fine_hidden_size - config.fine_matching_slice_dim
            (64 - 8 = 56 by default) features.
        - The second part used for the second stage are the last config.fine_matching_slice_dim (8 by default)
            features.

        Each part is used to compute a fine confidence tensor between each fine pixel in the first image and each
        fine pixel in the second image.

        Args:
            fine_features_0 (`torch.Tensor` of shape `(batch_size, num_keypoints, fine_kernel_size ** 2, fine_hidden_size)`):
                Fine features from the first image
            fine_features_1 (`torch.Tensor` of shape `(batch_size, num_keypoints, (fine_kernel_size + 2) ** 2, fine_hidden_size)`):
                Fine features from the second image
            coarse_matched_keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                Keypoint coordinates found in coarse matching for the first and second image
            fine_scale (`float`):
                Scale between the size of fine features and coarse features

        Returns:
            fine_coordinates (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
                Matched keypoint between the first and the second image after both refinement stages.
        """
        batch_size, num_keypoints, fine_window_size, fine_embed_dim = fine_features_0.shape
        fine_matching_slice_dim = self.config.fine_matching_slice_dim

        fine_kernel_size = torch_int(fine_window_size**0.5)

        # Split fine features into first and second stage features
        split_fine_features_0 = torch.split(fine_features_0, fine_embed_dim - fine_matching_slice_dim, -1)
        split_fine_features_1 = torch.split(fine_features_1, fine_embed_dim - fine_matching_slice_dim, -1)

        # Retrieve first stage fine features
        fine_features_0 = split_fine_features_0[0]
        fine_features_1 = split_fine_features_1[0]

        # Normalize first stage fine features
        fine_features_0 = fine_features_0 / fine_features_0.shape[-1] ** 0.5
        fine_features_1 = fine_features_1 / fine_features_1.shape[-1] ** 0.5

        # Compute first stage confidence
        fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)
        # Dual-softmax over the two fine-window axes. The tensor is 4D here
        # (batch_size, num_keypoints, window_0, window_1), so these are the last two axes; axis 1 is the keypoint
        # axis and must not be normalized over.
        fine_confidence = nn.functional.softmax(fine_confidence, -2) * nn.functional.softmax(fine_confidence, -1)
        fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
        )
        # Drop the one-pixel padding ring of the second image window.
        fine_confidence = fine_confidence[..., 1:-1, 1:-1]
        first_stage_fine_confidence = fine_confidence.reshape(
            batch_size, num_keypoints, fine_window_size, fine_window_size
        )

        fine_indices, fine_matches = self._get_first_stage_fine_matching(
            first_stage_fine_confidence,
            coarse_matched_keypoints,
            fine_window_size,
            fine_scale,
        )

        # Retrieve second stage fine features
        fine_features_0 = split_fine_features_0[1]
        fine_features_1 = split_fine_features_1[1]

        # Normalize second stage fine features
        fine_features_1 = fine_features_1 / fine_matching_slice_dim**0.5

        # Compute second stage fine confidence
        second_stage_fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)

        fine_coordinates = self._get_second_stage_fine_matching(
            fine_indices,
            fine_matches,
            second_stage_fine_confidence,
            fine_window_size,
            fine_scale,
        )

        return fine_coordinates

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> EfficientLoFTRKeypointMatchingOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoModelForKeypointMatching
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
        >>> with httpx.stream("GET", url) as response:
        ...     image1 = Image.open(BytesIO(response.read()))

        >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
        >>> with httpx.stream("GET", url) as response:
        ...     image2 = Image.open(BytesIO(response.read()))

        >>> images = [image1, image2]

        >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
        >>> model = AutoModelForKeypointMatching.from_pretrained("zju-community/efficient_loftr")

        >>> with torch.no_grad():
        ...     inputs = processor(images, return_tensors="pt")
        ...     outputs = model(**inputs)
        ```"""
        if labels is not None:
            raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.")

        # 1. Extract coarse and residual features
        model_outputs: BackboneOutput = self.efficientloftr(pixel_values, **kwargs)
        features = model_outputs.feature_maps

        # 2. Compute coarse-level matching
        coarse_features = features[0]
        coarse_height = coarse_features.shape[-2]
        batch_size, _, _, height, width = pixel_values.shape
        coarse_scale = height / coarse_height
        coarse_keypoints, coarse_matching_scores, coarse_matched_indices = self._coarse_matching(
            coarse_features, coarse_scale
        )

        # 3. Fine-level refinement
        residual_features = features[1:]
        coarse_features = coarse_features / self.config.hidden_size**0.5
        fine_features_0, fine_features_1 = self.refinement_layer(coarse_features, residual_features)

        # Filter fine features with coarse matches indices
        # The batch index is created on the same device as the match indices to avoid a host/device round trip.
        batch_indices = torch.arange(batch_size, device=coarse_matched_indices.device)[..., None]
        fine_features_0 = fine_features_0[batch_indices, coarse_matched_indices[:, 0]]
        fine_features_1 = fine_features_1[batch_indices, coarse_matched_indices[:, 1]]

        # 4. Compute fine-level matching
        fine_height = torch_int(coarse_height * coarse_scale)
        fine_scale = height / fine_height
        matching_keypoints = self._fine_matching(fine_features_0, fine_features_1, coarse_keypoints, fine_scale)

        # Express the keypoints relative to the image size: x is divided by the width, y by the height.
        matching_keypoints[:, :, :, 0] = matching_keypoints[:, :, :, 0] / width
        matching_keypoints[:, :, :, 1] = matching_keypoints[:, :, :, 1] / height

        loss = None
        return EfficientLoFTRKeypointMatchingOutput(
            loss=loss,
            matches=coarse_matched_indices,
            matching_scores=coarse_matching_scores,
            keypoints=matching_keypoints,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
# Names exported by `from <module> import *`: the base pretrained class, the feature model and the matching head.
__all__ = ["EfficientLoFTRPreTrainedModel", "EfficientLoFTRModel", "EfficientLoFTRForKeypointMatching"]
|