|
@@ -0,0 +1,667 @@
|
|
|
|
|
+import warnings
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from types import SimpleNamespace
|
|
|
|
|
+from typing import Callable, List, Optional, Tuple
|
|
|
|
|
+
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import torch
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
|
|
+from torch import nn
|
|
|
|
|
+
|
|
|
|
|
+try:
|
|
|
|
|
+ from flash_attn.modules.mha import FlashCrossAttention
|
|
|
|
|
+except ModuleNotFoundError:
|
|
|
|
|
+ FlashCrossAttention = None
|
|
|
|
|
+
|
|
|
|
|
+if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
|
|
|
|
|
+ FLASH_AVAILABLE = True
|
|
|
|
|
+else:
|
|
|
|
|
+ FLASH_AVAILABLE = False
|
|
|
|
|
+
|
|
|
|
|
+torch.backends.cudnn.deterministic = True
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+AMP_CUSTOM_FWD_F32 = (
|
|
|
|
|
+ torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
|
|
|
|
|
+ if hasattr(torch, "amp") and hasattr(torch.amp, "custom_fwd")
|
|
|
|
|
+ else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@AMP_CUSTOM_FWD_F32
|
|
|
|
|
+def normalize_keypoints(
|
|
|
|
|
+ kpts: torch.Tensor, size: Optional[torch.Tensor] = None
|
|
|
|
|
+) -> torch.Tensor:
|
|
|
|
|
+ if size is None:
|
|
|
|
|
+ size = 1 + kpts.max(-2).values - kpts.min(-2).values
|
|
|
|
|
+ elif not isinstance(size, torch.Tensor):
|
|
|
|
|
+ size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
|
|
|
|
|
+ size = size.to(kpts)
|
|
|
|
|
+ shift = size / 2
|
|
|
|
|
+ scale = size.max(-1).values / 2
|
|
|
|
|
+ kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
|
|
|
|
|
+ return kpts
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
|
|
|
|
|
+ if length <= x.shape[-2]:
|
|
|
|
|
+ return x, torch.ones_like(x[..., :1], dtype=torch.bool)
|
|
|
|
|
+ pad = torch.ones(
|
|
|
|
|
+ *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
|
|
|
|
|
+ )
|
|
|
|
|
+ y = torch.cat([x, pad], dim=-2)
|
|
|
|
|
+ mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
|
|
|
|
|
+ mask[..., : x.shape[-2], :] = True
|
|
|
|
|
+ return y, mask
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ x = x.unflatten(-1, (-1, 2))
|
|
|
|
|
+ x1, x2 = x.unbind(dim=-1)
|
|
|
|
|
+ return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ return (t * freqs[0]) + (rotate_half(t) * freqs[1])
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class LearnableFourierPositionalEncoding(nn.Module):
|
|
|
|
|
+ def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ F_dim = F_dim if F_dim is not None else dim
|
|
|
|
|
+ self.gamma = gamma
|
|
|
|
|
+ self.Wr = nn.Linear(M, F_dim // 2, bias=False)
|
|
|
|
|
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ """encode position vector"""
|
|
|
|
|
+ projected = self.Wr(x)
|
|
|
|
|
+ cosines, sines = torch.cos(projected), torch.sin(projected)
|
|
|
|
|
+ emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
|
|
|
|
|
+ return emb.repeat_interleave(2, dim=-1)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TokenConfidence(nn.Module):
|
|
|
|
|
+ def __init__(self, dim: int) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
|
|
|
|
|
+ """get confidence tokens"""
|
|
|
|
|
+ return (
|
|
|
|
|
+ self.token(desc0.detach()).squeeze(-1),
|
|
|
|
|
+ self.token(desc1.detach()).squeeze(-1),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class Attention(nn.Module):
|
|
|
|
|
+ def __init__(self, allow_flash: bool) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ if allow_flash and not FLASH_AVAILABLE:
|
|
|
|
|
+ warnings.warn(
|
|
|
|
|
+ "FlashAttention is not available. For optimal speed, "
|
|
|
|
|
+ "consider installing torch >= 2.0 or flash-attn.",
|
|
|
|
|
+ stacklevel=2,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.enable_flash = allow_flash and FLASH_AVAILABLE
|
|
|
|
|
+ self.has_sdp = hasattr(F, "scaled_dot_product_attention")
|
|
|
|
|
+ if allow_flash and FlashCrossAttention:
|
|
|
|
|
+ self.flash_ = FlashCrossAttention()
|
|
|
|
|
+ if self.has_sdp:
|
|
|
|
|
+ torch.backends.cuda.enable_flash_sdp(allow_flash)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
|
+ if q.shape[-2] == 0 or k.shape[-2] == 0:
|
|
|
|
|
+ return q.new_zeros((*q.shape[:-1], v.shape[-1]))
|
|
|
|
|
+ if self.enable_flash and q.device.type == "cuda":
|
|
|
|
|
+ # use torch 2.0 scaled_dot_product_attention with flash
|
|
|
|
|
+ if self.has_sdp:
|
|
|
|
|
+ args = [x.half().contiguous() for x in [q, k, v]]
|
|
|
|
|
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
|
|
|
|
|
+ return v if mask is None else v.nan_to_num()
|
|
|
|
|
+ else:
|
|
|
|
|
+ assert mask is None
|
|
|
|
|
+ q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
|
|
|
|
|
+ m = self.flash_(q.half(), torch.stack([k, v], 2).half())
|
|
|
|
|
+ return m.transpose(-2, -3).to(q.dtype).clone()
|
|
|
|
|
+ elif self.has_sdp:
|
|
|
|
|
+ args = [x.contiguous() for x in [q, k, v]]
|
|
|
|
|
+ v = F.scaled_dot_product_attention(*args, attn_mask=mask)
|
|
|
|
|
+ return v if mask is None else v.nan_to_num()
|
|
|
|
|
+ else:
|
|
|
|
|
+ s = q.shape[-1] ** -0.5
|
|
|
|
|
+ sim = torch.einsum("...id,...jd->...ij", q, k) * s
|
|
|
|
|
+ if mask is not None:
|
|
|
|
|
+ sim.masked_fill(~mask, -float("inf"))
|
|
|
|
|
+ attn = F.softmax(sim, -1)
|
|
|
|
|
+ return torch.einsum("...ij,...jd->...id", attn, v)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class SelfBlock(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.embed_dim = embed_dim
|
|
|
|
|
+ self.num_heads = num_heads
|
|
|
|
|
+ assert self.embed_dim % num_heads == 0
|
|
|
|
|
+ self.head_dim = self.embed_dim // num_heads
|
|
|
|
|
+ self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
|
|
|
|
|
+ self.inner_attn = Attention(flash)
|
|
|
|
|
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
|
|
|
|
+ self.ffn = nn.Sequential(
|
|
|
|
|
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
|
|
|
|
|
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
|
|
|
|
|
+ nn.GELU(),
|
|
|
|
|
+ nn.Linear(2 * embed_dim, embed_dim),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self,
|
|
|
|
|
+ x: torch.Tensor,
|
|
|
|
|
+ encoding: torch.Tensor,
|
|
|
|
|
+ mask: Optional[torch.Tensor] = None,
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ qkv = self.Wqkv(x)
|
|
|
|
|
+ qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
|
|
|
|
|
+ q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
|
|
|
|
|
+ q = apply_cached_rotary_emb(encoding, q)
|
|
|
|
|
+ k = apply_cached_rotary_emb(encoding, k)
|
|
|
|
|
+ context = self.inner_attn(q, k, v, mask=mask)
|
|
|
|
|
+ message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
|
|
|
|
|
+ return x + self.ffn(torch.cat([x, message], -1))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class CrossBlock(nn.Module):
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.heads = num_heads
|
|
|
|
|
+ dim_head = embed_dim // num_heads
|
|
|
|
|
+ self.scale = dim_head**-0.5
|
|
|
|
|
+ inner_dim = dim_head * num_heads
|
|
|
|
|
+ self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
|
|
|
|
|
+ self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
|
|
|
|
|
+ self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
|
|
|
|
|
+ self.ffn = nn.Sequential(
|
|
|
|
|
+ nn.Linear(2 * embed_dim, 2 * embed_dim),
|
|
|
|
|
+ nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
|
|
|
|
|
+ nn.GELU(),
|
|
|
|
|
+ nn.Linear(2 * embed_dim, embed_dim),
|
|
|
|
|
+ )
|
|
|
|
|
+ if flash and FLASH_AVAILABLE:
|
|
|
|
|
+ self.flash = Attention(True)
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.flash = None
|
|
|
|
|
+
|
|
|
|
|
+ def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
|
|
|
|
|
+ return func(x0), func(x1)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
|
|
|
|
|
+ ) -> List[torch.Tensor]:
|
|
|
|
|
+ qk0, qk1 = self.map_(self.to_qk, x0, x1)
|
|
|
|
|
+ v0, v1 = self.map_(self.to_v, x0, x1)
|
|
|
|
|
+ qk0, qk1, v0, v1 = map(
|
|
|
|
|
+ lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
|
|
|
|
|
+ (qk0, qk1, v0, v1),
|
|
|
|
|
+ )
|
|
|
|
|
+ if self.flash is not None and qk0.device.type == "cuda":
|
|
|
|
|
+ m0 = self.flash(qk0, qk1, v1, mask)
|
|
|
|
|
+ m1 = self.flash(
|
|
|
|
|
+ qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
|
|
|
|
|
+ sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
|
|
|
|
|
+ if mask is not None:
|
|
|
|
|
+ sim = sim.masked_fill(~mask, -float("inf"))
|
|
|
|
|
+ attn01 = F.softmax(sim, dim=-1)
|
|
|
|
|
+ attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
|
|
|
|
|
+ m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
|
|
|
|
|
+ m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
|
|
|
|
|
+ if mask is not None:
|
|
|
|
|
+ m0, m1 = m0.nan_to_num(), m1.nan_to_num()
|
|
|
|
|
+ m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
|
|
|
|
|
+ m0, m1 = self.map_(self.to_out, m0, m1)
|
|
|
|
|
+ x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
|
|
|
|
|
+ x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
|
|
|
|
|
+ return x0, x1
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TransformerLayer(nn.Module):
|
|
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.self_attn = SelfBlock(*args, **kwargs)
|
|
|
|
|
+ self.cross_attn = CrossBlock(*args, **kwargs)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(
|
|
|
|
|
+ self,
|
|
|
|
|
+ desc0,
|
|
|
|
|
+ desc1,
|
|
|
|
|
+ encoding0,
|
|
|
|
|
+ encoding1,
|
|
|
|
|
+ mask0: Optional[torch.Tensor] = None,
|
|
|
|
|
+ mask1: Optional[torch.Tensor] = None,
|
|
|
|
|
+ ):
|
|
|
|
|
+ if mask0 is not None and mask1 is not None:
|
|
|
|
|
+ return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
|
|
|
|
|
+ else:
|
|
|
|
|
+ desc0 = self.self_attn(desc0, encoding0)
|
|
|
|
|
+ desc1 = self.self_attn(desc1, encoding1)
|
|
|
|
|
+ return self.cross_attn(desc0, desc1)
|
|
|
|
|
+
|
|
|
|
|
+ # This part is compiled and allows padding inputs
|
|
|
|
|
+ def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
|
|
|
|
|
+ mask = mask0 & mask1.transpose(-1, -2)
|
|
|
|
|
+ mask0 = mask0 & mask0.transpose(-1, -2)
|
|
|
|
|
+ mask1 = mask1 & mask1.transpose(-1, -2)
|
|
|
|
|
+ desc0 = self.self_attn(desc0, encoding0, mask0)
|
|
|
|
|
+ desc1 = self.self_attn(desc1, encoding1, mask1)
|
|
|
|
|
+ return self.cross_attn(desc0, desc1, mask)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def sigmoid_log_double_softmax(
|
|
|
|
|
+ sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
|
|
|
|
|
+) -> torch.Tensor:
|
|
|
|
|
+ """create the log assignment matrix from logits and similarity"""
|
|
|
|
|
+ b, m, n = sim.shape
|
|
|
|
|
+ certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
|
|
|
|
|
+ scores0 = F.log_softmax(sim, 2)
|
|
|
|
|
+ scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
|
|
|
|
|
+ scores = sim.new_full((b, m + 1, n + 1), 0)
|
|
|
|
|
+ scores[:, :m, :n] = scores0 + scores1 + certainties
|
|
|
|
|
+ scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
|
|
|
|
|
+ scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
|
|
|
|
|
+ return scores
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class MatchAssignment(nn.Module):
|
|
|
|
|
+ def __init__(self, dim: int) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.dim = dim
|
|
|
|
|
+ self.matchability = nn.Linear(dim, 1, bias=True)
|
|
|
|
|
+ self.final_proj = nn.Linear(dim, dim, bias=True)
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
|
|
|
|
|
+ """build assignment matrix from descriptors"""
|
|
|
|
|
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
|
|
|
|
|
+ _, _, d = mdesc0.shape
|
|
|
|
|
+ mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
|
|
|
|
|
+ sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
|
|
|
|
|
+ z0 = self.matchability(desc0)
|
|
|
|
|
+ z1 = self.matchability(desc1)
|
|
|
|
|
+ scores = sigmoid_log_double_softmax(sim, z0, z1)
|
|
|
|
|
+ return scores, sim
|
|
|
|
|
+
|
|
|
|
|
+ def get_matchability(self, desc: torch.Tensor):
|
|
|
|
|
+ return torch.sigmoid(self.matchability(desc)).squeeze(-1)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def filter_matches(scores: torch.Tensor, th: float):
|
|
|
|
|
+ """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
|
|
|
|
|
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
|
|
|
|
|
+ m0, m1 = max0.indices, max1.indices
|
|
|
|
|
+ indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
|
|
|
|
|
+ indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
|
|
|
|
|
+ mutual0 = indices0 == m1.gather(1, m0)
|
|
|
|
|
+ mutual1 = indices1 == m0.gather(1, m1)
|
|
|
|
|
+ max0_exp = max0.values.exp()
|
|
|
|
|
+ zero = max0_exp.new_tensor(0)
|
|
|
|
|
+ mscores0 = torch.where(mutual0, max0_exp, zero)
|
|
|
|
|
+ mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
|
|
|
|
|
+ valid0 = mutual0 & (mscores0 > th)
|
|
|
|
|
+ valid1 = mutual1 & valid0.gather(1, m1)
|
|
|
|
|
+ m0 = torch.where(valid0, m0, -1)
|
|
|
|
|
+ m1 = torch.where(valid1, m1, -1)
|
|
|
|
|
+ return m0, m1, mscores0, mscores1
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class LightGlue(nn.Module):
|
|
|
|
|
+ default_conf = {
|
|
|
|
|
+ "name": "lightglue", # just for interfacing
|
|
|
|
|
+ "input_dim": 256, # input descriptor dimension (autoselected from weights)
|
|
|
|
|
+ "descriptor_dim": 256,
|
|
|
|
|
+ "add_scale_ori": False,
|
|
|
|
|
+ "n_layers": 9,
|
|
|
|
|
+ "num_heads": 4,
|
|
|
|
|
+ "flash": True, # enable FlashAttention if available.
|
|
|
|
|
+ "mp": False, # enable mixed precision
|
|
|
|
|
+ "depth_confidence": 0.95, # early stopping, disable with -1
|
|
|
|
|
+ "width_confidence": 0.99, # point pruning, disable with -1
|
|
|
|
|
+ "filter_threshold": 0.1, # match threshold
|
|
|
|
|
+ "weights": None,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # Point pruning involves an overhead (gather).
|
|
|
|
|
+ # Therefore, we only activate it if there are enough keypoints.
|
|
|
|
|
+ pruning_keypoint_thresholds = {
|
|
|
|
|
+ "cpu": -1,
|
|
|
|
|
+ "mps": -1,
|
|
|
|
|
+ "cuda": 1024,
|
|
|
|
|
+ "flash": 1536,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ required_data_keys = ["image0", "image1"]
|
|
|
|
|
+
|
|
|
|
|
+ version = "v0.1_arxiv"
|
|
|
|
|
+ url = "https://github.com/cvg/LightGlue/releases/download/{}/{}.pth"
|
|
|
|
|
+
|
|
|
|
|
+ features = {
|
|
|
|
|
+ "superpoint": {
|
|
|
|
|
+ "weights": "superpoint_lightglue",
|
|
|
|
|
+ "input_dim": 256,
|
|
|
|
|
+ },
|
|
|
|
|
+ "disk": {
|
|
|
|
|
+ "weights": "disk_lightglue",
|
|
|
|
|
+ "input_dim": 128,
|
|
|
|
|
+ },
|
|
|
|
|
+ "aliked": {
|
|
|
|
|
+ "weights": "aliked_lightglue",
|
|
|
|
|
+ "input_dim": 128,
|
|
|
|
|
+ },
|
|
|
|
|
+ "raco-aliked": {
|
|
|
|
|
+ "weights": "raco_aliked_lightglue",
|
|
|
|
|
+ "input_dim": 128,
|
|
|
|
|
+ },
|
|
|
|
|
+ "sift": {
|
|
|
|
|
+ "weights": "sift_lightglue",
|
|
|
|
|
+ "input_dim": 128,
|
|
|
|
|
+ "add_scale_ori": True,
|
|
|
|
|
+ },
|
|
|
|
|
+ "doghardnet": {
|
|
|
|
|
+ "weights": "doghardnet_lightglue",
|
|
|
|
|
+ "input_dim": 128,
|
|
|
|
|
+ "add_scale_ori": True,
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, features="superpoint", **conf) -> None:
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
|
|
|
|
|
+ if features is not None:
|
|
|
|
|
+ if features not in self.features:
|
|
|
|
|
+ raise ValueError(
|
|
|
|
|
+ f"Unsupported features: {features} not in "
|
|
|
|
|
+ f"{{{','.join(self.features)}}}"
|
|
|
|
|
+ )
|
|
|
|
|
+ for k, v in self.features[features].items():
|
|
|
|
|
+ setattr(conf, k, v)
|
|
|
|
|
+
|
|
|
|
|
+ if conf.input_dim != conf.descriptor_dim:
|
|
|
|
|
+ self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.input_proj = nn.Identity()
|
|
|
|
|
+
|
|
|
|
|
+ head_dim = conf.descriptor_dim // conf.num_heads
|
|
|
|
|
+ self.posenc = LearnableFourierPositionalEncoding(
|
|
|
|
|
+ 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
|
|
|
|
|
+
|
|
|
|
|
+ self.transformers = nn.ModuleList(
|
|
|
|
|
+ [TransformerLayer(d, h, conf.flash) for _ in range(n)]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
|
|
|
|
|
+ self.token_confidence = nn.ModuleList(
|
|
|
|
|
+ [TokenConfidence(d) for _ in range(n - 1)]
|
|
|
|
|
+ )
|
|
|
|
|
+ self.register_buffer(
|
|
|
|
|
+ "confidence_thresholds",
|
|
|
|
|
+ torch.Tensor(
|
|
|
|
|
+ [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
|
|
|
|
|
+ ),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ state_dict = None
|
|
|
|
|
+ if features is not None:
|
|
|
|
|
+ fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
|
|
|
|
|
+ state_dict = torch.hub.load_state_dict_from_url(
|
|
|
|
|
+ self.url.format(self.version, self.conf.weights),
|
|
|
|
|
+ file_name=fname,
|
|
|
|
|
+ )
|
|
|
|
|
+ self.load_state_dict(state_dict, strict=False)
|
|
|
|
|
+ elif conf.weights is not None:
|
|
|
|
|
+ path = Path(__file__).parent
|
|
|
|
|
+ path = path / "weights/{}.pth".format(self.conf.weights)
|
|
|
|
|
+ state_dict = torch.load(str(path), map_location="cpu")
|
|
|
|
|
+
|
|
|
|
|
+ if state_dict:
|
|
|
|
|
+ # rename old state dict entries
|
|
|
|
|
+ for i in range(self.conf.n_layers):
|
|
|
|
|
+ pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
|
|
|
|
|
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
|
|
|
|
+ pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
|
|
|
|
|
+ state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
|
|
|
|
+ self.load_state_dict(state_dict, strict=False)
|
|
|
|
|
+
|
|
|
|
|
+ # static lengths LightGlue is compiled for (only used with torch.compile)
|
|
|
|
|
+ self.static_lengths = None
|
|
|
|
|
+
|
|
|
|
|
+ def compile(
|
|
|
|
|
+ self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
|
|
|
|
|
+ ):
|
|
|
|
|
+ if self.conf.width_confidence != -1:
|
|
|
|
|
+ warnings.warn(
|
|
|
|
|
+ "Point pruning is partially disabled for compiled forward.",
|
|
|
|
|
+ stacklevel=2,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ torch._inductor.cudagraph_mark_step_begin()
|
|
|
|
|
+ for i in range(self.conf.n_layers):
|
|
|
|
|
+ self.transformers[i].masked_forward = torch.compile(
|
|
|
|
|
+ self.transformers[i].masked_forward, mode=mode, fullgraph=True
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ self.static_lengths = static_lengths
|
|
|
|
|
+
|
|
|
|
|
+ def forward(self, data: dict) -> dict:
|
|
|
|
|
+ """
|
|
|
|
|
+ Match keypoints and descriptors between two images
|
|
|
|
|
+
|
|
|
|
|
+ Input (dict):
|
|
|
|
|
+ image0: dict
|
|
|
|
|
+ keypoints: [B x M x 2]
|
|
|
|
|
+ descriptors: [B x M x D]
|
|
|
|
|
+ image: [B x C x H x W] or image_size: [B x 2]
|
|
|
|
|
+ image1: dict
|
|
|
|
|
+ keypoints: [B x N x 2]
|
|
|
|
|
+ descriptors: [B x N x D]
|
|
|
|
|
+ image: [B x C x H x W] or image_size: [B x 2]
|
|
|
|
|
+ Output (dict):
|
|
|
|
|
+ matches0: [B x M]
|
|
|
|
|
+ matching_scores0: [B x M]
|
|
|
|
|
+ matches1: [B x N]
|
|
|
|
|
+ matching_scores1: [B x N]
|
|
|
|
|
+ matches: List[[Si x 2]]
|
|
|
|
|
+ scores: List[[Si]]
|
|
|
|
|
+ stop: int
|
|
|
|
|
+ prune0: [B x M]
|
|
|
|
|
+ prune1: [B x N]
|
|
|
|
|
+ """
|
|
|
|
|
+ with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
|
|
|
|
|
+ return self._forward(data)
|
|
|
|
|
+
|
|
|
|
|
+ def _forward(self, data: dict) -> dict:
|
|
|
|
|
+ for key in self.required_data_keys:
|
|
|
|
|
+ assert key in data, f"Missing key {key} in data"
|
|
|
|
|
+ data0, data1 = data["image0"], data["image1"]
|
|
|
|
|
+ kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
|
|
|
|
|
+ b, m, _ = kpts0.shape
|
|
|
|
|
+ b, n, _ = kpts1.shape
|
|
|
|
|
+ device = kpts0.device
|
|
|
|
|
+ size0, size1 = data0.get("image_size"), data1.get("image_size")
|
|
|
|
|
+ kpts0 = normalize_keypoints(kpts0, size0).clone()
|
|
|
|
|
+ kpts1 = normalize_keypoints(kpts1, size1).clone()
|
|
|
|
|
+
|
|
|
|
|
+ if self.conf.add_scale_ori:
|
|
|
|
|
+ kpts0 = torch.cat(
|
|
|
|
|
+ [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
|
|
|
|
+ )
|
|
|
|
|
+ kpts1 = torch.cat(
|
|
|
|
|
+ [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
|
|
|
|
+ )
|
|
|
|
|
+ desc0 = data0["descriptors"].detach().contiguous()
|
|
|
|
|
+ desc1 = data1["descriptors"].detach().contiguous()
|
|
|
|
|
+
|
|
|
|
|
+ assert desc0.shape[-1] == self.conf.input_dim
|
|
|
|
|
+ assert desc1.shape[-1] == self.conf.input_dim
|
|
|
|
|
+
|
|
|
|
|
+ if torch.is_autocast_enabled():
|
|
|
|
|
+ desc0 = desc0.half()
|
|
|
|
|
+ desc1 = desc1.half()
|
|
|
|
|
+
|
|
|
|
|
+ mask0, mask1 = None, None
|
|
|
|
|
+ c = max(m, n)
|
|
|
|
|
+ do_compile = self.static_lengths and c <= max(self.static_lengths)
|
|
|
|
|
+ if do_compile:
|
|
|
|
|
+ kn = min([k for k in self.static_lengths if k >= c])
|
|
|
|
|
+ desc0, mask0 = pad_to_length(desc0, kn)
|
|
|
|
|
+ desc1, mask1 = pad_to_length(desc1, kn)
|
|
|
|
|
+ kpts0, _ = pad_to_length(kpts0, kn)
|
|
|
|
|
+ kpts1, _ = pad_to_length(kpts1, kn)
|
|
|
|
|
+ desc0 = self.input_proj(desc0)
|
|
|
|
|
+ desc1 = self.input_proj(desc1)
|
|
|
|
|
+ # cache positional embeddings
|
|
|
|
|
+ encoding0 = self.posenc(kpts0)
|
|
|
|
|
+ encoding1 = self.posenc(kpts1)
|
|
|
|
|
+
|
|
|
|
|
+ # GNN + final_proj + assignment
|
|
|
|
|
+ do_early_stop = self.conf.depth_confidence > 0
|
|
|
|
|
+ do_point_pruning = self.conf.width_confidence > 0 and not do_compile
|
|
|
|
|
+ pruning_th = self.pruning_min_kpts(device)
|
|
|
|
|
+ if do_point_pruning:
|
|
|
|
|
+ ind0 = torch.arange(0, m, device=device)[None]
|
|
|
|
|
+ ind1 = torch.arange(0, n, device=device)[None]
|
|
|
|
|
+ # We store the index of the layer at which pruning is detected.
|
|
|
|
|
+ prune0 = torch.ones_like(ind0)
|
|
|
|
|
+ prune1 = torch.ones_like(ind1)
|
|
|
|
|
+ token0, token1 = None, None
|
|
|
|
|
+ for i in range(self.conf.n_layers):
|
|
|
|
|
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
|
|
|
|
+ break
|
|
|
|
|
+ desc0, desc1 = self.transformers[i](
|
|
|
|
|
+ desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
|
|
|
|
|
+ )
|
|
|
|
|
+ if i == self.conf.n_layers - 1:
|
|
|
|
|
+ continue # no early stopping or adaptive width at last layer
|
|
|
|
|
+
|
|
|
|
|
+ if do_early_stop:
|
|
|
|
|
+ token0, token1 = self.token_confidence[i](desc0, desc1)
|
|
|
|
|
+ if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
|
|
|
|
|
+ break
|
|
|
|
|
+ if do_point_pruning and desc0.shape[-2] > pruning_th:
|
|
|
|
|
+ scores0 = self.log_assignment[i].get_matchability(desc0)
|
|
|
|
|
+ prunemask0 = self.get_pruning_mask(token0, scores0, i)
|
|
|
|
|
+ keep0 = torch.where(prunemask0)[1]
|
|
|
|
|
+ ind0 = ind0.index_select(1, keep0)
|
|
|
|
|
+ desc0 = desc0.index_select(1, keep0)
|
|
|
|
|
+ encoding0 = encoding0.index_select(-2, keep0)
|
|
|
|
|
+ prune0[:, ind0] += 1
|
|
|
|
|
+ if do_point_pruning and desc1.shape[-2] > pruning_th:
|
|
|
|
|
+ scores1 = self.log_assignment[i].get_matchability(desc1)
|
|
|
|
|
+ prunemask1 = self.get_pruning_mask(token1, scores1, i)
|
|
|
|
|
+ keep1 = torch.where(prunemask1)[1]
|
|
|
|
|
+ ind1 = ind1.index_select(1, keep1)
|
|
|
|
|
+ desc1 = desc1.index_select(1, keep1)
|
|
|
|
|
+ encoding1 = encoding1.index_select(-2, keep1)
|
|
|
|
|
+ prune1[:, ind1] += 1
|
|
|
|
|
+
|
|
|
|
|
+ if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
|
|
|
|
+ m0 = desc0.new_full((b, m), -1, dtype=torch.long)
|
|
|
|
|
+ m1 = desc1.new_full((b, n), -1, dtype=torch.long)
|
|
|
|
|
+ mscores0 = desc0.new_zeros((b, m))
|
|
|
|
|
+ mscores1 = desc1.new_zeros((b, n))
|
|
|
|
|
+ matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
|
|
|
|
|
+ mscores = desc0.new_empty((b, 0))
|
|
|
|
|
+ if not do_point_pruning:
|
|
|
|
|
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
|
|
|
|
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
|
|
|
|
+ return {
|
|
|
|
|
+ "matches0": m0,
|
|
|
|
|
+ "matches1": m1,
|
|
|
|
|
+ "matching_scores0": mscores0,
|
|
|
|
|
+ "matching_scores1": mscores1,
|
|
|
|
|
+ "stop": i + 1,
|
|
|
|
|
+ "matches": matches,
|
|
|
|
|
+ "scores": mscores,
|
|
|
|
|
+ "prune0": prune0,
|
|
|
|
|
+ "prune1": prune1,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
|
|
|
|
|
+ scores, _ = self.log_assignment[i](desc0, desc1)
|
|
|
|
|
+ m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
|
|
|
|
|
+ matches, mscores = [], []
|
|
|
|
|
+ for k in range(b):
|
|
|
|
|
+ valid = m0[k] > -1
|
|
|
|
|
+ m_indices_0 = torch.where(valid)[0]
|
|
|
|
|
+ m_indices_1 = m0[k][valid]
|
|
|
|
|
+ if do_point_pruning:
|
|
|
|
|
+ m_indices_0 = ind0[k, m_indices_0]
|
|
|
|
|
+ m_indices_1 = ind1[k, m_indices_1]
|
|
|
|
|
+ matches.append(torch.stack([m_indices_0, m_indices_1], -1))
|
|
|
|
|
+ mscores.append(mscores0[k][valid])
|
|
|
|
|
+
|
|
|
|
|
+ # TODO: Remove when hloc switches to the compact format.
|
|
|
|
|
+ if do_point_pruning:
|
|
|
|
|
+ m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
|
|
|
|
|
+ m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
|
|
|
|
|
+ m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
|
|
|
|
|
+ m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
|
|
|
|
|
+ mscores0_ = torch.zeros((b, m), device=mscores0.device)
|
|
|
|
|
+ mscores1_ = torch.zeros((b, n), device=mscores1.device)
|
|
|
|
|
+ mscores0_[:, ind0] = mscores0
|
|
|
|
|
+ mscores1_[:, ind1] = mscores1
|
|
|
|
|
+ m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
|
|
|
|
|
+ else:
|
|
|
|
|
+ prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
|
|
|
|
+ prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "matches0": m0,
|
|
|
|
|
+ "matches1": m1,
|
|
|
|
|
+ "matching_scores0": mscores0,
|
|
|
|
|
+ "matching_scores1": mscores1,
|
|
|
|
|
+ "stop": i + 1,
|
|
|
|
|
+ "matches": matches,
|
|
|
|
|
+ "scores": mscores,
|
|
|
|
|
+ "prune0": prune0,
|
|
|
|
|
+ "prune1": prune1,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ def confidence_threshold(self, layer_index: int) -> float:
|
|
|
|
|
+ """scaled confidence threshold"""
|
|
|
|
|
+ threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
|
|
|
|
|
+ return np.clip(threshold, 0, 1)
|
|
|
|
|
+
|
|
|
|
|
+ def get_pruning_mask(
|
|
|
|
|
+ self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ """mask points which should be removed"""
|
|
|
|
|
+ keep = scores > (1 - self.conf.width_confidence)
|
|
|
|
|
+ if confidences is not None: # Low-confidence points are never pruned.
|
|
|
|
|
+ keep |= confidences <= self.confidence_thresholds[layer_index]
|
|
|
|
|
+ return keep
|
|
|
|
|
+
|
|
|
|
|
+ def check_if_stop(
|
|
|
|
|
+ self,
|
|
|
|
|
+ confidences0: torch.Tensor,
|
|
|
|
|
+ confidences1: torch.Tensor,
|
|
|
|
|
+ layer_index: int,
|
|
|
|
|
+ num_points: int,
|
|
|
|
|
+ ) -> torch.Tensor:
|
|
|
|
|
+ """evaluate stopping condition"""
|
|
|
|
|
+ confidences = torch.cat([confidences0, confidences1], -1)
|
|
|
|
|
+ threshold = self.confidence_thresholds[layer_index]
|
|
|
|
|
+ ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
|
|
|
|
|
+ return ratio_confident > self.conf.depth_confidence
|
|
|
|
|
+
|
|
|
|
|
+ def pruning_min_kpts(self, device: torch.device):
|
|
|
|
|
+ if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
|
|
|
|
|
+ return self.pruning_keypoint_thresholds["flash"]
|
|
|
|
|
+ else:
|
|
|
|
|
+ return self.pruning_keypoint_thresholds[device.type]
|