| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import math
- import warnings
- from pathlib import Path
- from types import SimpleNamespace
- from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
- import torch
- import torch.nn.functional as F
- from torch import nn
- from kornia.core import (
- Module,
- ModuleList,
- Tensor,
- arange,
- concatenate,
- cos,
- einsum,
- ones,
- ones_like,
- sin,
- stack,
- where,
- zeros,
- )
- from kornia.core.check import KORNIA_CHECK
- from kornia.feature.laf import laf_to_three_points, scale_laf
- from kornia.utils._compat import custom_fwd
# flash-attn is an optional dependency: keep a ``None`` placeholder when it is
# not installed so later code can test for it.
try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

# A fused attention implementation is considered available when either
# flash-attn could be imported or torch exposes ``scaled_dot_product_attention``.
FLASH_AVAILABLE = bool(FlashCrossAttention) or hasattr(F, "scaled_dot_product_attention")
def math_clamp(x, min_, max_):  # type: ignore
    """Restrict a scalar to the closed interval ``[min_, max_]``.

    Args:
        x: value to clamp.
        min_: lower bound.
        max_: upper bound.

    Returns:
        ``x`` raised to at least ``min_`` and then capped at ``max_``.
    """
    floored = max(x, min_)
    return min(floored, max_)
@custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(kpts: Tensor, size: Tensor) -> Tensor:
    """Center keypoints on the image and scale them by half the longest side.

    Args:
        kpts: keypoint coordinates; indexed here as a batch of point sets.
        size: per-batch image size tensor, or a ``torch.Size`` which is
            converted to a tensor with a leading batch axis of one.

    Returns:
        ``(kpts - size / 2) / (max(size) / 2)`` computed per batch element.
    """
    if isinstance(size, torch.Size):
        size = torch.tensor(size)[None]
    # Half extent per axis (the image centre) and half of the longest side.
    half_extent = size.float().to(kpts) / 2
    half_longest = size.max(1).values.float().to(kpts) / 2
    centered = kpts - half_extent[:, None]
    return centered / half_longest[:, None, None]
def pad_to_length(x: Tensor, length: int) -> Tuple[Tensor, Tensor]:
    """Pad ``x`` along its second-to-last axis up to ``length`` entries.

    Args:
        x: tensor to pad; the padded axis is ``-2``.
        length: target size of axis ``-2``.

    Returns:
        A tuple ``(padded, mask)``. ``padded`` is ``x`` itself when it is
        already at least ``length`` long, otherwise ``x`` followed by rows of
        ones. ``mask`` is a boolean tensor with a trailing axis of size one
        that is ``True`` for original rows and ``False`` for padding rows.
    """
    current = x.shape[-2]
    if current >= length:
        # Nothing to add: every row is a real one.
        return x, ones_like(x[..., :1], dtype=torch.bool)

    filler = ones(*x.shape[:-2], length - current, x.shape[-1], device=x.device, dtype=x.dtype)
    padded = concatenate([x, filler], dim=-2)
    valid = zeros(*padded.shape[:-1], 1, dtype=torch.bool, device=x.device)
    valid[..., :current, :] = True
    return padded, valid
def rotate_half(x: Tensor) -> Tensor:
    """Rotate consecutive channel pairs: ``(a, b)`` becomes ``(-b, a)``.

    The last axis is viewed as pairs of adjacent elements, each pair is
    rotated, and the result is flattened back to the input shape.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)
def apply_cached_rotary_emb(freqs: Tensor, t: Tensor) -> Tensor:
    """Apply a precomputed rotary embedding to ``t``.

    Args:
        freqs: stacked embedding whose entry ``0`` multiplies ``t`` and whose
            entry ``1`` multiplies the pair-rotated ``t``.
        t: tensor to encode.

    Returns:
        ``t * freqs[0] + rotate_half(t) * freqs[1]``.
    """
    direct_term = t * freqs[0]
    rotated_term = rotate_half(t) * freqs[1]
    return direct_term + rotated_term
class LearnableFourierPositionalEncoding(Module):
    """Fourier positional encoding with a learnable linear projection.

    Positions are projected by a bias-free linear layer; the cosine and sine of
    the projection are stacked and each channel is repeated twice so the
    result lines up with pairwise rotary application.
    """

    def __init__(self, M: int, dim: int, F_dim: Optional[int] = None, gamma: float = 1.0) -> None:
        """Build the projection.

        Args:
            M: number of input position channels.
            dim: fallback for ``F_dim`` when that is not given.
            F_dim: size of the encoding; the projection outputs ``F_dim // 2``.
            gamma: the projection weights are drawn from a zero-mean normal
                with standard deviation ``gamma ** -2``.
        """
        super().__init__()
        if F_dim is None:
            F_dim = dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: Tensor) -> Tensor:
        """Encode position vectors.

        Returns:
            Tensor whose leading axis of size two holds the cosine and sine
            parts, with an extra singleton axis inserted at ``-3`` and every
            channel duplicated along the last axis.
        """
        phases = self.Wr(x)
        emb = stack([cos(phases), sin(phases)], 0)
        emb = emb.unsqueeze(-3)
        return emb.repeat_interleave(2, dim=-1)
class TokenConfidence(Module):
    """Per-token confidence head: a linear layer followed by a sigmoid."""

    def __init__(self, dim: int) -> None:
        """Create the head for descriptors with ``dim`` channels."""
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: Tensor, desc1: Tensor) -> Tuple[Tensor, Tensor]:
        """Score both descriptor sets.

        The descriptors are detached before scoring, evaluated in the dtype of
        the head's weights, and both results are cast to the dtype of
        ``desc0``. The trailing singleton channel is squeezed away.
        """
        head_dtype = self.token[0].weight.dtype
        out_dtype = desc0.dtype

        def score(desc: Tensor) -> Tensor:
            return self.token(desc.detach().to(head_dtype)).squeeze(-1).to(out_dtype)

        return score(desc0), score(desc1)
class Attention(Module):
    """Scaled dot-product attention with optional fused kernels.

    Depending on availability and device, the forward pass dispatches to
    torch's ``scaled_dot_product_attention``, to flash-attn's
    ``FlashCrossAttention``, or to a plain einsum/softmax fallback.
    """

    def __init__(self, allow_flash: bool) -> None:
        """Configure which attention backend may be used.

        Args:
            allow_flash: request the fused implementation. A warning is emitted
                when it is requested but unavailable.

        Note:
            When torch provides ``scaled_dot_product_attention`` this also
            calls ``torch.backends.cuda.enable_flash_sdp(allow_flash)``, which
            is a process-wide setting.
        """
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()
        if self.has_sdp:
            torch.backends.cuda.enable_flash_sdp(allow_flash)

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries.
            k: keys.
            v: values.
            mask: optional boolean mask, ``True`` where attention is allowed.

        Returns:
            The attended values. Whenever a mask is given, NaNs (produced by
            rows that may attend to nothing) are replaced via ``nan_to_num``.
        """
        if self.enable_flash and q.device.type == "cuda":
            # use torch 2.0 scaled_dot_product_attention with flash
            if self.has_sdp:
                args = [x.half().contiguous() for x in [q, k, v]]
                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)  # type: ignore
                return v if mask is None else v.nan_to_num()
            else:
                # flash-attn path does not support masking.
                KORNIA_CHECK(mask is None)
                q, k, v = (x.transpose(-2, -3).contiguous() for x in [q, k, v])
                m = self.flash_(q.half(), stack([k, v], 2).half())
                return m.transpose(-2, -3).to(q.dtype).clone()
        elif self.has_sdp:
            args = [x.contiguous() for x in [q, k, v]]
            v = F.scaled_dot_product_attention(*args, attn_mask=mask)  # type: ignore
            return v if mask is None else v.nan_to_num()
        else:
            s = q.shape[-1] ** -0.5
            sim = einsum("...id,...jd->...ij", q, k) * s
            if mask is not None:
                # masked_fill is out-of-place: the result must be kept,
                # otherwise the mask is silently ignored.
                sim = sim.masked_fill(~mask, -float("inf"))
            attn = F.softmax(sim, -1)
            out = einsum("...ij,...jd->...id", attn, v)
            # Mirror the SDP branches: fully masked rows yield NaN after the
            # softmax and are mapped to zeros.
            return out if mask is None else out.nan_to_num()
class SelfBlock(Module):
    """Self-attention block with rotary position encoding and a residual MLP."""

    def __init__(self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True) -> None:
        """Build the block.

        Args:
            embed_dim: descriptor dimension; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            flash: forwarded to :class:`Attention`.
            bias: whether the qkv and output projections use a bias.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        KORNIA_CHECK(self.embed_dim % num_heads == 0, "Embed dimension should be dividable by num_heads")
        self.head_dim = self.embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        hidden = 2 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
        )

    def forward(
        self,
        x: Tensor,
        encoding: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Update ``x`` with a self-attention message.

        Args:
            x: input descriptors.
            encoding: cached rotary embedding applied to queries and keys.
            mask: optional attention mask forwarded to the attention module.

        Returns:
            ``x`` plus the MLP output computed from ``[x, message]``.
        """
        # Project once, then split the last axis into (heads, head_dim, 3).
        fused = self.Wqkv(x).unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        q, k, v = (fused[..., part] for part in range(3))
        q, k = (apply_cached_rotary_emb(encoding, t) for t in (q, k))
        attended = self.inner_attn(q, k, v, mask=mask)
        # Merge the heads back into a single channel axis.
        merged = attended.transpose(1, 2).flatten(start_dim=-2)
        message = self.out_proj(merged)
        return x + self.ffn(concatenate([x, message], -1))
class CrossBlock(Module):
    """Bidirectional cross-attention block between two descriptor sets.

    Queries and keys share one projection, so a single similarity matrix
    serves both directions.
    """

    def __init__(self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True) -> None:
        """Build the block.

        Args:
            embed_dim: descriptor dimension.
            num_heads: number of attention heads.
            flash: use the fused :class:`Attention` module when available.
            bias: whether the linear projections use a bias.
        """
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        hidden = 2 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden, embed_dim),
        )
        if flash and FLASH_AVAILABLE:
            self.flash = Attention(True)
        else:
            self.flash = None  # type: ignore

    def map_(self, func: Callable, x0: Tensor, x1: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        """Apply ``func`` to both inputs and return the two results."""
        return func(x0), func(x1)

    def forward(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Exchange messages between ``x0`` and ``x1``.

        Args:
            x0: descriptors of the first set.
            x1: descriptors of the second set.
            mask: optional boolean mask over (x0, x1) pairs, ``True`` where
                attention is allowed.

        Returns:
            The updated ``(x0, x1)``.
        """

        def split_heads(t: Tensor) -> Tensor:
            return t.unflatten(-1, (self.heads, -1)).transpose(1, 2)

        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        qk0, qk1, v0, v1 = (split_heads(t) for t in (qk0, qk1, v0, v1))

        use_flash = self.flash is not None and qk0.device.type == "cuda"
        if use_flash:
            mask_t = None if mask is None else mask.transpose(-1, -2)
            m0 = self.flash(qk0, qk1, v1, mask)
            m1 = self.flash(qk1, qk0, v0, mask_t)
        else:
            # Each side carries sqrt(scale) so their product carries scale.
            root_scale = self.scale**0.5
            qk0, qk1 = qk0 * root_scale, qk1 * root_scale
            sim = einsum("bhid, bhjd -> bhij", qk0, qk1)
            if mask is not None:
                sim = sim.masked_fill(~mask, -float("inf"))
            attn01 = F.softmax(sim, dim=-1)
            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
            if mask is not None:
                # Rows that may attend to nothing come out as NaN.
                m0, m1 = m0.nan_to_num(), m1.nan_to_num()

        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        out0 = x0 + self.ffn(concatenate([x0, m0], -1))
        out1 = x1 + self.ffn(concatenate([x1, m1], -1))
        return out0, out1
class TransformerLayer(Module):
    """One matcher layer: self-attention on each set, then cross-attention."""

    def __init__(self, *args, **kwargs):  # type: ignore
        """Forward all arguments to both :class:`SelfBlock` and :class:`CrossBlock`."""
        super().__init__()
        self.self_attn = SelfBlock(*args, **kwargs)
        self.cross_attn = CrossBlock(*args, **kwargs)

    def forward(
        self,
        desc0: Tensor,
        desc1: Tensor,
        encoding0: Tensor,
        encoding1: Tensor,
        mask0: Optional[Tensor] = None,
        mask1: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the layer; the masked path is taken only when both masks are given.

        Returns:
            The result of the cross-attention block for the two updated
            descriptor sets.
        """
        if mask0 is None or mask1 is None:
            desc0 = self.self_attn(desc0, encoding0)
            desc1 = self.self_attn(desc1, encoding1)
            return self.cross_attn(desc0, desc1)
        return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)

    # This part is compiled and allows padding inputs
    def masked_forward(
        self, desc0: Tensor, desc1: Tensor, encoding0: Tensor, encoding1: Tensor, mask0: Tensor, mask1: Tensor
    ) -> Tensor:
        """Run the layer on padded inputs.

        The per-point validity masks are combined into pairwise masks: a pair
        is allowed only when both of its points are valid.
        """
        cross_mask = mask0 & mask1.transpose(-1, -2)
        self_mask0 = mask0 & mask0.transpose(-1, -2)
        self_mask1 = mask1 & mask1.transpose(-1, -2)
        desc0 = self.self_attn(desc0, encoding0, self_mask0)
        desc1 = self.self_attn(desc1, encoding1, self_mask1)
        return self.cross_attn(desc0, desc1, cross_mask)
def sigmoid_log_double_softmax(sim: Tensor, z0: Tensor, z1: Tensor) -> Tensor:
    """Build the log assignment matrix from similarities and matchability logits.

    Args:
        sim: similarity matrix of shape ``[B x M x N]``.
        z0: matchability logits of the first set, broadcast over columns.
        z1: matchability logits of the second set, broadcast over rows.

    Returns:
        Tensor of shape ``[B x M+1 x N+1]``. The inner ``M x N`` part is the
        sum of the row-wise and column-wise log-softmax of ``sim`` plus the
        log-sigmoid of both logits. The last column and last row hold the
        log-sigmoid of the negated logits, and the corner entry stays zero.
    """
    b, m, n = sim.shape
    log_rows = F.log_softmax(sim, 2)
    log_cols = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    matchable = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)

    scores = sim.new_full((b, m + 1, n + 1), 0)
    scores[:, :m, :n] = log_rows + log_cols + matchable
    # Extra column / row: log-probability of being unmatchable.
    scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
    scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
    return scores
class MatchAssignment(Module):
    """Assignment head producing log assignment scores from descriptors."""

    def __init__(self, dim: int) -> None:
        """Create the matchability and projection layers for ``dim`` channels."""
        super().__init__()
        self.dim = dim
        self.matchability = nn.Linear(dim, 1, bias=True)
        self.final_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, desc0: Tensor, desc1: Tensor) -> Tuple[Tensor, Tensor]:
        """Build the assignment matrix from descriptors.

        Returns:
            ``(scores, sim)``: the log assignment matrix and the raw
            similarity matrix of the projected descriptors.
        """
        proj0, proj1 = self.final_proj(desc0), self.final_proj(desc1)
        _, _, d = proj0.shape
        # Each side is divided by d**0.25, so the product is divided by sqrt(d).
        norm = d**0.25
        sim = einsum("bmd,bnd->bmn", proj0 / norm, proj1 / norm)
        logits0 = self.matchability(desc0)
        logits1 = self.matchability(desc1)
        return sigmoid_log_double_softmax(sim, logits0, logits1), sim

    def get_matchability(self, desc: Tensor) -> Tensor:
        """Return the sigmoid of the matchability logit for every descriptor."""
        logits = self.matchability(desc)
        return torch.sigmoid(logits).squeeze(-1)
def filter_matches(scores: Tensor, th: float) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Extract mutual matches from a log assignment matrix ``[B x M+1 x N+1]``.

    The last row and column are ignored. A pair is kept when each point is the
    other's best candidate and the exponentiated score exceeds ``th``.

    Args:
        scores: log assignment matrix.
        th: minimum matching score for a match to be valid.

    Returns:
        ``(m0, m1, mscores0, mscores1)``: for each point of either set the
        index of its match (``-1`` when there is none) and the matching score
        (zero for non-mutual candidates).
    """
    inner = scores[:, :-1, :-1]
    best0, best1 = inner.max(2), inner.max(1)
    m0, m1 = best0.indices, best1.indices

    # A candidate is mutual when following it back returns the starting index.
    idx0 = arange(m0.shape[1], device=m0.device)[None]
    idx1 = arange(m1.shape[1], device=m1.device)[None]
    mutual0 = idx0 == m1.gather(1, m0)
    mutual1 = idx1 == m0.gather(1, m1)

    prob0 = best0.values.exp()
    zero = prob0.new_tensor(0)
    mscores0 = where(mutual0, prob0, zero)
    mscores1 = where(mutual1, mscores0.gather(1, m1), zero)

    valid0 = mutual0 & (mscores0 > th)
    valid1 = mutual1 & valid0.gather(1, m1)
    return where(valid0, m0, -1), where(valid1, m1, -1), mscores0, mscores1
class LightGlue(Module):
    """LightGlue sparse feature matcher.

    A stack of transformer layers alternates self- and cross-attention over the
    descriptors of two images. After each layer a confidence head can stop the
    network early (``depth_confidence``) and unmatchable points can be pruned
    (``width_confidence``). Matches are read from the assignment head of the
    last executed layer.
    """

    default_conf: ClassVar[Dict[str, Any]] = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "add_scale_ori": False,
        "add_laf": False,  # for KeyNetAffNetHardNet
        "scale_coef": 1.0,  # to compensate for the SIFT scale bigger than KeyNet
        "n_layers": 9,
        "num_heads": 4,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": 0.95,  # early stopping, disable with -1
        "width_confidence": 0.99,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }
    # Point pruning involves an overhead (gather).
    # Therefore, we only activate it if there are enough keypoints.
    pruning_keypoint_thresholds: ClassVar[Dict[str, Any]] = {
        "cpu": -1,
        "mps": -1,
        "cuda": 1024,
        "flash": 1536,
    }
    required_data_keys: ClassVar[List[str]] = ["image0", "image1"]
    version: ClassVar[str] = "v0.1_arxiv"
    url: ClassVar[str] = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
    # Per-feature presets; they are applied on top of the configuration.
    features: ClassVar[Dict[str, Any]] = {
        "superpoint": {
            "weights": "superpoint_lightglue",
            "input_dim": 256,
        },
        "dedodeb": {
            "weights": "dedodeb_lightglue",
            "input_dim": 256,
        },
        "dedodeg": {
            "weights": "dedodeg_lightglue",
            "input_dim": 256,
        },
        "disk": {
            "weights": "disk_lightglue",
            "input_dim": 128,
        },
        "aliked": {
            "weights": "aliked_lightglue",
            "input_dim": 128,
        },
        "sift": {
            "weights": "sift_lightglue",
            "input_dim": 128,
            "add_scale_ori": True,
        },
        "dog_affnet_hardnet": {
            "weights": "doghardnet",
            "input_dim": 128,
            "width_confidence": -1,
            "depth_confidence": -1,
            "add_scale_ori": True,
            "scale_coef": 1.0 / 6.0,
        },
        "doghardnet": {
            "weights": "doghardnet",
            "input_dim": 128,
            "width_confidence": -1,
            "depth_confidence": -1,
            "add_scale_ori": True,
            "scale_coef": 1.0 / 6.0,
        },
        "keynet_affnet_hardnet": {
            "weights": "keynet_affnet_hardnet_lightglue",
            "input_dim": 128,
            "width_confidence": -1,
            "depth_confidence": -1,
            "add_laf": True,
        },
    }

    def __init__(self, features: str = "superpoint", **conf_) -> None:  # type: ignore
        """Build the network and load pretrained weights.

        Args:
            features: key of :attr:`features` selecting a preset and the
                pretrained weights to download; ``None`` skips the preset, in
                which case ``weights`` from the configuration (if set) is
                loaded from the local ``weights`` directory.
            **conf_: overrides for :attr:`default_conf`. Values of the selected
                feature preset take precedence over these overrides.
        """
        super().__init__()
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf_})
        if features is not None:
            KORNIA_CHECK(features in self.features, "Features keys are wrong")
            for k, v in self.features[features].items():
                setattr(conf, k, v)
        KORNIA_CHECK(not (self.conf.add_scale_ori and self.conf.add_laf))  # we use either scale ori, or LAF

        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()  # type: ignore

        head_dim = conf.descriptor_dim // conf.num_heads
        # Position input: (x, y), plus scale/orientation or two extra LAF points.
        self.posenc = LearnableFourierPositionalEncoding(
            2 + 2 * conf.add_scale_ori + 4 * conf.add_laf,
            head_dim,
            head_dim,
        )

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
        self.transformers = ModuleList([TransformerLayer(d, h, conf.flash) for _ in range(n)])
        self.log_assignment = ModuleList([MatchAssignment(d) for _ in range(n)])
        self.token_confidence = ModuleList([TokenConfidence(d) for _ in range(n - 1)])
        self.register_buffer(
            "confidence_thresholds",
            Tensor([self.confidence_threshold(i) for i in range(self.conf.n_layers)]),
        )

        state_dict = None
        if features is not None:
            fname = f"{conf.weights}_{self.version}.pth".replace(".", "-")
            if features == "dog_affnet_hardnet":
                features = "doghardnet"  # new dog model is better for affnet as well
            if features in ["keynet_affnet_hardnet"]:
                fname = "keynet_affnet_hardnet_lightlue.pth"
                url = "http://cmp.felk.cvut.cz/~mishkdmy/models/keynet_affnet_hardnet_lightlue.pth"
            elif features in ["dedodeb"]:
                fname = "dedodeb_lightglue.pth"
                url = "http://cmp.felk.cvut.cz/~mishkdmy/models/dedodeb_lightglue.pth"
            elif features in ["dedodeg"]:
                fname = "dedodeg_lightglue.pth"
                url = "http://cmp.felk.cvut.cz/~mishkdmy/models/dedodeg_lightglue.pth"
            else:
                url = self.url.format(self.version, features)
            state_dict = torch.hub.load_state_dict_from_url(url, file_name=fname)
        elif conf.weights is not None:
            path = Path(__file__).parent
            path = path / f"weights/{self.conf.weights}.pth"
            state_dict = torch.load(str(path), map_location="cpu")

        if state_dict:
            # rename old state dict entries
            for i in range(self.conf.n_layers):
                pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
                pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            self.load_state_dict(state_dict, strict=False)

        print("Loaded LightGlue model")

        # static lengths LightGlue is compiled for (only used with torch.compile)
        self.static_lengths = None

    def compile(
        self, mode: str = "reduce-overhead", static_lengths: Sequence[int] = (256, 512, 768, 1024, 1280, 1536)
    ) -> None:
        """Compile the masked forward of every layer with ``torch.compile``.

        Args:
            mode: compilation mode passed to ``torch.compile``.
            static_lengths: keypoint counts the inputs get padded to; inputs
                larger than the maximum run uncompiled.
        """
        if self.conf.width_confidence != -1:
            warnings.warn(
                "Point pruning is partially disabled for compiled forward.",
                stacklevel=2,
            )

        for i in range(self.conf.n_layers):
            self.transformers[i].masked_forward = torch.compile(  # type: ignore[assignment]
                self.transformers[i].masked_forward, mode=mode, fullgraph=True
            )

        self.static_lengths = static_lengths  # type: ignore

    def forward(self, data: dict) -> dict:  # type: ignore
        """Match keypoints and descriptors between two images.

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                image: [B x C x H x W] or image_size: [B x 2]
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                image: [B x C x H x W] or image_size: [B x 2]

        Output (dict):
            log_assignment: [B x M+1 x N+1]
            matches0: [B x M]
            matching_scores0: [B x M]
            matches1: [B x N]
            matching_scores1: [B x N]
            matches: List[[Si x 2]], scores: List[[Si]]
        """
        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
            return self._forward(data)

    @staticmethod
    def _strip_padding(confidences: Tensor, num_valid: int) -> Tensor:
        """Keep only the first ``num_valid`` per-point confidences of a ``[B x K]`` tensor."""
        return confidences[..., :num_valid]

    def _forward(self, data: dict) -> dict:  # type: ignore
        """Run the matcher; see :meth:`forward` for the input and output layout."""
        for key in self.required_data_keys:
            KORNIA_CHECK(key in data, f"Missing key {key} in data")
        data0, data1 = data["image0"], data["image1"]
        kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
        b, m, _ = kpts0.shape
        b, n, _ = kpts1.shape
        device = kpts0.device
        size0, size1 = data0.get("image_size"), data1.get("image_size")
        # Fall back to the image tensor: its (H, W) is reversed into (W, H).
        size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1]
        size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1]

        kpts0 = normalize_keypoints(kpts0, size0).clone()
        kpts1 = normalize_keypoints(kpts1, size1).clone()
        KORNIA_CHECK(
            torch.all(kpts0 >= -1).item() and torch.all(kpts0 <= 1).item(),  # type: ignore
            "Normalized keypoints of image0 must lie in [-1, 1]; check them against the image size",
        )
        KORNIA_CHECK(
            torch.all(kpts1 >= -1).item() and torch.all(kpts1 <= 1).item(),  # type: ignore
            "Normalized keypoints of image1 must lie in [-1, 1]; check them against the image size",
        )

        if self.conf.add_scale_ori:
            kpts0 = concatenate([kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1)
            if self.conf.scale_coef != 1.0:
                kpts0[..., -2] = kpts0[..., -2] * self.conf.scale_coef
            kpts1 = concatenate([kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1)
            if self.conf.scale_coef != 1.0:
                kpts1[..., -2] = kpts1[..., -2] * self.conf.scale_coef
        elif self.conf.add_laf:
            laf0 = scale_laf(data0["lafs"], self.conf.scale_coef)
            laf1 = scale_laf(data1["lafs"], self.conf.scale_coef)
            laf0 = laf_to_three_points(laf0)
            laf1 = laf_to_three_points(laf1)
            kpts0 = concatenate(
                [
                    kpts0,
                    normalize_keypoints(laf0[..., 0], size0).clone().to(kpts0.dtype),
                    normalize_keypoints(laf0[..., 1], size0).clone().to(kpts0.dtype),
                ],
                -1,
            )
            kpts1 = concatenate(
                [
                    kpts1,
                    normalize_keypoints(laf1[..., 0], size1).clone().to(kpts1.dtype),
                    normalize_keypoints(laf1[..., 1], size1).clone().to(kpts1.dtype),
                ],
                -1,
            )

        desc0 = data0["descriptors"].detach().contiguous()
        desc1 = data1["descriptors"].detach().contiguous()

        KORNIA_CHECK(desc0.shape[-1] == self.conf.input_dim, "Descriptor dimension does not match input dim in config")
        KORNIA_CHECK(desc1.shape[-1] == self.conf.input_dim, "Descriptor dimension does not match input dim in config")

        if torch.is_autocast_enabled():
            desc0 = desc0.half()
            desc1 = desc1.half()

        mask0, mask1 = None, None
        c = max(m, n)
        do_compile = self.static_lengths and c <= max(self.static_lengths)
        if do_compile:
            # Pad both sets to the smallest static length that fits them.
            kn = min(k for k in self.static_lengths if k >= c)
            desc0, mask0 = pad_to_length(desc0, kn)
            desc1, mask1 = pad_to_length(desc1, kn)
            kpts0, _ = pad_to_length(kpts0, kn)
            kpts1, _ = pad_to_length(kpts1, kn)

        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        # GNN + final_proj + assignment
        do_early_stop = self.conf.depth_confidence > 0
        do_point_pruning = self.conf.width_confidence > 0 and not do_compile
        pruning_th = self.pruning_min_kpts(device)
        if do_point_pruning:
            ind0 = arange(0, m, device=device)[None]
            ind1 = arange(0, n, device=device)[None]
            # We store the index of the layer at which pruning is detected.
            prune0 = ones_like(ind0)
            prune1 = ones_like(ind1)
        token0, token1 = None, None
        for i in range(self.conf.n_layers):
            desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1)
            if i == self.conf.n_layers - 1:
                continue  # no early stopping or adaptive width at last layer

            if do_early_stop:
                token0, token1 = self.token_confidence[i](desc0, desc1)
                # The confidences are [B x K]; drop the padded points (last
                # axis) so they do not count as unconfident.
                if self.check_if_stop(self._strip_padding(token0, m), self._strip_padding(token1, n), i, m + n):
                    break
            if do_point_pruning and desc0.shape[-2] > pruning_th:
                scores0 = self.log_assignment[i].get_matchability(desc0)
                prunemask0 = self.get_pruning_mask(token0, scores0, i)  # type: ignore
                keep0 = where(prunemask0)[1]
                ind0 = ind0.index_select(1, keep0)
                desc0 = desc0.index_select(1, keep0)
                encoding0 = encoding0.index_select(-2, keep0)
                prune0[:, ind0] += 1
            if do_point_pruning and desc1.shape[-2] > pruning_th:
                scores1 = self.log_assignment[i].get_matchability(desc1)
                prunemask1 = self.get_pruning_mask(token1, scores1, i)  # type: ignore
                keep1 = where(prunemask1)[1]
                ind1 = ind1.index_select(1, keep1)
                desc1 = desc1.index_select(1, keep1)
                encoding1 = encoding1.index_select(-2, keep1)
                prune1[:, ind1] += 1

        # Remove padding rows before the assignment; ``i`` is the last layer run.
        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
        scores, _ = self.log_assignment[i](desc0, desc1)
        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)

        matches, mscores = [], []
        for k in range(b):
            valid = m0[k] > -1
            m_indices_0 = where(valid)[0]
            m_indices_1 = m0[k][valid]
            if do_point_pruning:
                # Map indices of surviving points back to the original ones.
                m_indices_0 = ind0[k, m_indices_0]
                m_indices_1 = ind1[k, m_indices_1]
            matches.append(stack([m_indices_0, m_indices_1], -1))
            mscores.append(mscores0[k][valid])

        # TODO: Remove when hloc switches to the compact format.
        if do_point_pruning:
            m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
            m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
            m0_[:, ind0] = where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
            m1_[:, ind1] = where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
            mscores0_ = zeros((b, m), device=mscores0.device)
            mscores1_ = zeros((b, n), device=mscores1.device)
            mscores0_[:, ind0] = mscores0
            mscores1_[:, ind1] = mscores1
            m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
        else:
            prune0 = ones_like(mscores0) * self.conf.n_layers
            prune1 = ones_like(mscores1) * self.conf.n_layers

        pred = {
            "log_assignment": scores,
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "stop": i + 1,
            "matches": matches,
            "scores": mscores,
            "prune0": prune0,
            "prune1": prune1,
        }

        return pred

    def confidence_threshold(self, layer_index: int) -> float:
        """Scaled confidence threshold.

        Returns ``0.8 + 0.1 * exp(-4 * layer_index / n_layers)`` clamped to
        ``[0, 1]``.
        """
        threshold = 0.8 + 0.1 * math.exp(-4.0 * layer_index / self.conf.n_layers)
        return min(max(threshold, 0), 1)

    def get_pruning_mask(self, confidences: Tensor, scores: Tensor, layer_index: int) -> Tensor:
        """Mask points which should be removed.

        Args:
            confidences: per-point confidences, or ``None`` when early stopping
                is disabled.
            scores: per-point matchability.
            layer_index: layer whose confidence threshold applies.

        Returns:
            Boolean mask that is ``True`` for the points to keep.
        """
        keep = scores > (1 - self.conf.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: Tensor,
        confidences1: Tensor,
        layer_index: int,
        num_points: int,
    ) -> Tensor:
        """Evaluate stopping condition.

        Args:
            confidences0: per-point confidences of the first set.
            confidences1: per-point confidences of the second set.
            layer_index: layer whose confidence threshold applies.
            num_points: denominator of the confident ratio.

        Returns:
            Boolean tensor, true when the ratio of confident points exceeds
            ``depth_confidence``.
        """
        confidences = concatenate([confidences0, confidences1], -1)
        threshold = self.confidence_thresholds[layer_index]
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.conf.depth_confidence

    def pruning_min_kpts(self, device: torch.device) -> int:
        """Return the keypoint count above which pruning is applied on ``device``."""
        if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
            return self.pruning_keypoint_thresholds["flash"]
        else:
            return self.pruning_keypoint_thresholds[device.type]
|