lightglue.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import math
  18. import warnings
  19. from pathlib import Path
  20. from types import SimpleNamespace
  21. from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from kornia.core import (
  26. Module,
  27. ModuleList,
  28. Tensor,
  29. arange,
  30. concatenate,
  31. cos,
  32. einsum,
  33. ones,
  34. ones_like,
  35. sin,
  36. stack,
  37. where,
  38. zeros,
  39. )
  40. from kornia.core.check import KORNIA_CHECK
  41. from kornia.feature.laf import laf_to_three_points, scale_laf
  42. from kornia.utils._compat import custom_fwd
  43. try:
  44. from flash_attn.modules.mha import FlashCrossAttention
  45. except ModuleNotFoundError:
  46. FlashCrossAttention = None
  47. if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
  48. FLASH_AVAILABLE = True
  49. else:
  50. FLASH_AVAILABLE = False
  51. def math_clamp(x, min_, max_): # type: ignore
  52. """Clamp a value to lie within [min, max]."""
  53. return min(max(x, min_), max_)
  54. @custom_fwd(cast_inputs=torch.float32)
  55. def normalize_keypoints(kpts: Tensor, size: Tensor) -> Tensor:
  56. """Normalize tensor of keypoints."""
  57. if isinstance(size, torch.Size):
  58. size = torch.tensor(size)[None]
  59. shift = size.float().to(kpts) / 2
  60. scale = size.max(1).values.float().to(kpts) / 2
  61. kpts = (kpts - shift[:, None]) / scale[:, None, None]
  62. return kpts
  63. def pad_to_length(x: Tensor, length: int) -> Tuple[Tensor, Tensor]:
  64. """Pad tensor to desired length."""
  65. if length <= x.shape[-2]:
  66. return x, ones_like(x[..., :1], dtype=torch.bool)
  67. pad = ones(*x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
  68. y = concatenate([x, pad], dim=-2)
  69. mask = zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
  70. mask[..., : x.shape[-2], :] = True
  71. return y, mask
  72. def rotate_half(x: Tensor) -> Tensor:
  73. """Apply half rotation."""
  74. x = x.unflatten(-1, (-1, 2))
  75. x1, x2 = x.unbind(dim=-1)
  76. return stack((-x2, x1), dim=-1).flatten(start_dim=-2)
  77. def apply_cached_rotary_emb(freqs: Tensor, t: Tensor) -> Tensor:
  78. """Apply rotary embedding."""
  79. return (t * freqs[0]) + (rotate_half(t) * freqs[1])
  80. class LearnableFourierPositionalEncoding(Module):
  81. def __init__(self, M: int, dim: int, F_dim: Optional[int] = None, gamma: float = 1.0) -> None:
  82. super().__init__()
  83. F_dim = F_dim if F_dim is not None else dim
  84. self.gamma = gamma
  85. self.Wr = nn.Linear(M, F_dim // 2, bias=False)
  86. nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
  87. def forward(self, x: Tensor) -> Tensor:
  88. """Encode position vector."""
  89. projected = self.Wr(x)
  90. cosines, sines = cos(projected), sin(projected)
  91. emb = stack([cosines, sines], 0).unsqueeze(-3)
  92. return emb.repeat_interleave(2, dim=-1)
  93. class TokenConfidence(Module):
  94. def __init__(self, dim: int) -> None:
  95. super().__init__()
  96. self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
  97. def forward(self, desc0: Tensor, desc1: Tensor) -> Tuple[Tensor, Tensor]:
  98. """Get confidence tokens."""
  99. dtype = self.token[0].weight.dtype
  100. orig_dtype = desc0.dtype
  101. return (
  102. self.token(desc0.detach().to(dtype)).squeeze(-1).to(orig_dtype),
  103. self.token(desc1.detach().to(dtype)).squeeze(-1).to(orig_dtype),
  104. )
  105. class Attention(Module):
  106. def __init__(self, allow_flash: bool) -> None:
  107. super().__init__()
  108. if allow_flash and not FLASH_AVAILABLE:
  109. warnings.warn(
  110. "FlashAttention is not available. For optimal speed, consider installing torch >= 2.0 or flash-attn.",
  111. stacklevel=2,
  112. )
  113. self.enable_flash = allow_flash and FLASH_AVAILABLE
  114. self.has_sdp = hasattr(F, "scaled_dot_product_attention")
  115. if allow_flash and FlashCrossAttention:
  116. self.flash_ = FlashCrossAttention()
  117. if self.has_sdp:
  118. torch.backends.cuda.enable_flash_sdp(allow_flash)
  119. def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None) -> Tensor:
  120. if self.enable_flash and q.device.type == "cuda":
  121. # use torch 2.0 scaled_dot_product_attention with flash
  122. if self.has_sdp:
  123. args = [x.half().contiguous() for x in [q, k, v]]
  124. v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype) # type: ignore
  125. return v if mask is None else v.nan_to_num()
  126. else:
  127. KORNIA_CHECK(mask is None)
  128. q, k, v = (x.transpose(-2, -3).contiguous() for x in [q, k, v])
  129. m = self.flash_(q.half(), stack([k, v], 2).half())
  130. return m.transpose(-2, -3).to(q.dtype).clone()
  131. elif self.has_sdp:
  132. args = [x.contiguous() for x in [q, k, v]]
  133. v = F.scaled_dot_product_attention(*args, attn_mask=mask) # type: ignore
  134. return v if mask is None else v.nan_to_num()
  135. else:
  136. s = q.shape[-1] ** -0.5
  137. sim = einsum("...id,...jd->...ij", q, k) * s
  138. if mask is not None:
  139. sim.masked_fill(~mask, -float("inf"))
  140. attn = F.softmax(sim, -1)
  141. return einsum("...ij,...jd->...id", attn, v)
  142. class SelfBlock(Module):
  143. def __init__(self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True) -> None:
  144. super().__init__()
  145. self.embed_dim = embed_dim
  146. self.num_heads = num_heads
  147. KORNIA_CHECK(self.embed_dim % num_heads == 0, "Embed dimension should be dividable by num_heads")
  148. self.head_dim = self.embed_dim // num_heads
  149. self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
  150. self.inner_attn = Attention(flash)
  151. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  152. self.ffn = nn.Sequential(
  153. nn.Linear(2 * embed_dim, 2 * embed_dim),
  154. nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
  155. nn.GELU(),
  156. nn.Linear(2 * embed_dim, embed_dim),
  157. )
  158. def forward(
  159. self,
  160. x: Tensor,
  161. encoding: Tensor,
  162. mask: Optional[Tensor] = None,
  163. ) -> Tensor:
  164. qkv = self.Wqkv(x)
  165. qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
  166. q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
  167. q = apply_cached_rotary_emb(encoding, q)
  168. k = apply_cached_rotary_emb(encoding, k)
  169. context = self.inner_attn(q, k, v, mask=mask)
  170. message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
  171. return x + self.ffn(concatenate([x, message], -1))
  172. class CrossBlock(Module):
  173. def __init__(self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True) -> None:
  174. super().__init__()
  175. self.heads = num_heads
  176. dim_head = embed_dim // num_heads
  177. self.scale = dim_head**-0.5
  178. inner_dim = dim_head * num_heads
  179. self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
  180. self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
  181. self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
  182. self.ffn = nn.Sequential(
  183. nn.Linear(2 * embed_dim, 2 * embed_dim),
  184. nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
  185. nn.GELU(),
  186. nn.Linear(2 * embed_dim, embed_dim),
  187. )
  188. if flash and FLASH_AVAILABLE:
  189. self.flash = Attention(True)
  190. else:
  191. self.flash = None # type: ignore
  192. def map_(self, func: Callable, x0: Tensor, x1: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
  193. return func(x0), func(x1)
  194. def forward(self, x0: Tensor, x1: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
  195. qk0, qk1 = self.map_(self.to_qk, x0, x1)
  196. v0, v1 = self.map_(self.to_v, x0, x1)
  197. qk0, qk1, v0, v1 = (t.unflatten(-1, (self.heads, -1)).transpose(1, 2) for t in (qk0, qk1, v0, v1))
  198. if self.flash is not None and qk0.device.type == "cuda":
  199. m0 = self.flash(qk0, qk1, v1, mask)
  200. m1 = self.flash(qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None)
  201. else:
  202. qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
  203. sim = einsum("bhid, bhjd -> bhij", qk0, qk1)
  204. if mask is not None:
  205. sim = sim.masked_fill(~mask, -float("inf"))
  206. attn01 = F.softmax(sim, dim=-1)
  207. attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
  208. m0 = einsum("bhij, bhjd -> bhid", attn01, v1)
  209. m1 = einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
  210. if mask is not None:
  211. m0, m1 = m0.nan_to_num(), m1.nan_to_num()
  212. m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
  213. m0, m1 = self.map_(self.to_out, m0, m1)
  214. x0 = x0 + self.ffn(concatenate([x0, m0], -1))
  215. x1 = x1 + self.ffn(concatenate([x1, m1], -1))
  216. return x0, x1
  217. class TransformerLayer(Module):
  218. def __init__(self, *args, **kwargs): # type: ignore
  219. super().__init__()
  220. self.self_attn = SelfBlock(*args, **kwargs)
  221. self.cross_attn = CrossBlock(*args, **kwargs)
  222. def forward(
  223. self,
  224. desc0: Tensor,
  225. desc1: Tensor,
  226. encoding0: Tensor,
  227. encoding1: Tensor,
  228. mask0: Optional[Tensor] = None,
  229. mask1: Optional[Tensor] = None,
  230. ) -> Tensor:
  231. if mask0 is not None and mask1 is not None:
  232. return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
  233. else:
  234. desc0 = self.self_attn(desc0, encoding0)
  235. desc1 = self.self_attn(desc1, encoding1)
  236. return self.cross_attn(desc0, desc1)
  237. # This part is compiled and allows padding inputs
  238. def masked_forward(
  239. self, desc0: Tensor, desc1: Tensor, encoding0: Tensor, encoding1: Tensor, mask0: Tensor, mask1: Tensor
  240. ) -> Tensor:
  241. mask = mask0 & mask1.transpose(-1, -2)
  242. mask0 = mask0 & mask0.transpose(-1, -2)
  243. mask1 = mask1 & mask1.transpose(-1, -2)
  244. desc0 = self.self_attn(desc0, encoding0, mask0)
  245. desc1 = self.self_attn(desc1, encoding1, mask1)
  246. return self.cross_attn(desc0, desc1, mask)
  247. def sigmoid_log_double_softmax(sim: Tensor, z0: Tensor, z1: Tensor) -> Tensor:
  248. """Create the log assignment matrix from logits and similarity."""
  249. b, m, n = sim.shape
  250. certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
  251. scores0 = F.log_softmax(sim, 2)
  252. scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
  253. scores = sim.new_full((b, m + 1, n + 1), 0)
  254. scores[:, :m, :n] = scores0 + scores1 + certainties
  255. scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
  256. scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
  257. return scores
  258. class MatchAssignment(Module):
  259. def __init__(self, dim: int) -> None:
  260. super().__init__()
  261. self.dim = dim
  262. self.matchability = nn.Linear(dim, 1, bias=True)
  263. self.final_proj = nn.Linear(dim, dim, bias=True)
  264. def forward(self, desc0: Tensor, desc1: Tensor) -> Tuple[Tensor, Tensor]:
  265. """Build assignment matrix from descriptors."""
  266. mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
  267. _, _, d = mdesc0.shape
  268. mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
  269. sim = einsum("bmd,bnd->bmn", mdesc0, mdesc1)
  270. z0 = self.matchability(desc0)
  271. z1 = self.matchability(desc1)
  272. scores = sigmoid_log_double_softmax(sim, z0, z1)
  273. return scores, sim
  274. def get_matchability(self, desc: Tensor) -> Tensor:
  275. return torch.sigmoid(self.matchability(desc)).squeeze(-1)
  276. def filter_matches(scores: Tensor, th: float) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
  277. """Obtain matches from a log assignment matrix [Bx M+1 x N+1]."""
  278. max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
  279. m0, m1 = max0.indices, max1.indices
  280. indices0 = arange(m0.shape[1], device=m0.device)[None]
  281. indices1 = arange(m1.shape[1], device=m1.device)[None]
  282. mutual0 = indices0 == m1.gather(1, m0)
  283. mutual1 = indices1 == m0.gather(1, m1)
  284. max0_exp = max0.values.exp()
  285. zero = max0_exp.new_tensor(0)
  286. mscores0 = where(mutual0, max0_exp, zero)
  287. mscores1 = where(mutual1, mscores0.gather(1, m1), zero)
  288. valid0 = mutual0 & (mscores0 > th)
  289. valid1 = mutual1 & valid0.gather(1, m1)
  290. m0 = where(valid0, m0, -1)
  291. m1 = where(valid1, m1, -1)
  292. return m0, m1, mscores0, mscores1
  293. class LightGlue(Module):
  294. default_conf: ClassVar[Dict[str, Any]] = {
  295. "name": "lightglue", # just for interfacing
  296. "input_dim": 256, # input descriptor dimension (autoselected from weights)
  297. "descriptor_dim": 256,
  298. "add_scale_ori": False,
  299. "add_laf": False, # for KeyNetAffNetHardNet
  300. "scale_coef": 1.0, # to compensate for the SIFT scale bigger than KeyNet
  301. "n_layers": 9,
  302. "num_heads": 4,
  303. "flash": True, # enable FlashAttention if available.
  304. "mp": False, # enable mixed precision
  305. "depth_confidence": 0.95, # early stopping, disable with -1
  306. "width_confidence": 0.99, # point pruning, disable with -1
  307. "filter_threshold": 0.1, # match threshold
  308. "weights": None,
  309. }
  310. # Point pruning involves an overhead (gather).
  311. # Therefore, we only activate it if there are enough keypoints.
  312. pruning_keypoint_thresholds: ClassVar[Dict[str, Any]] = {
  313. "cpu": -1,
  314. "mps": -1,
  315. "cuda": 1024,
  316. "flash": 1536,
  317. }
  318. required_data_keys: ClassVar[List[str]] = ["image0", "image1"]
  319. version: ClassVar[str] = "v0.1_arxiv"
  320. url: ClassVar[str] = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
  321. features: ClassVar[Dict[str, Any]] = {
  322. "superpoint": {
  323. "weights": "superpoint_lightglue",
  324. "input_dim": 256,
  325. },
  326. "dedodeb": {
  327. "weights": "dedodeb_lightglue",
  328. "input_dim": 256,
  329. },
  330. "dedodeg": {
  331. "weights": "dedodeg_lightglue",
  332. "input_dim": 256,
  333. },
  334. "disk": {
  335. "weights": "disk_lightglue",
  336. "input_dim": 128,
  337. },
  338. "aliked": {
  339. "weights": "aliked_lightglue",
  340. "input_dim": 128,
  341. },
  342. "sift": {
  343. "weights": "sift_lightglue",
  344. "input_dim": 128,
  345. "add_scale_ori": True,
  346. },
  347. "dog_affnet_hardnet": {
  348. "weights": "doghardnet",
  349. "input_dim": 128,
  350. "width_confidence": -1,
  351. "depth_confidence": -1,
  352. "add_scale_ori": True,
  353. "scale_coef": 1.0 / 6.0,
  354. },
  355. "doghardnet": {
  356. "weights": "doghardnet",
  357. "input_dim": 128,
  358. "width_confidence": -1,
  359. "depth_confidence": -1,
  360. "add_scale_ori": True,
  361. "scale_coef": 1.0 / 6.0,
  362. },
  363. "keynet_affnet_hardnet": {
  364. "weights": "keynet_affnet_hardnet_lightglue",
  365. "input_dim": 128,
  366. "width_confidence": -1,
  367. "depth_confidence": -1,
  368. "add_laf": True,
  369. },
  370. }
  371. def __init__(self, features: str = "superpoint", **conf_) -> None: # type: ignore
  372. super().__init__()
  373. self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf_})
  374. if features is not None:
  375. KORNIA_CHECK(features in list(self.features.keys()), "Features keys are wrong")
  376. for k, v in self.features[features].items():
  377. setattr(conf, k, v)
  378. KORNIA_CHECK(not (self.conf.add_scale_ori and self.conf.add_laf)) # we use either scale ori, or LAF
  379. if conf.input_dim != conf.descriptor_dim:
  380. self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
  381. else:
  382. self.input_proj = nn.Identity() # type: ignore
  383. head_dim = conf.descriptor_dim // conf.num_heads
  384. self.posenc = LearnableFourierPositionalEncoding(
  385. 2 + 2 * conf.add_scale_ori + 4 * conf.add_laf,
  386. head_dim,
  387. head_dim,
  388. )
  389. h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
  390. self.transformers = ModuleList([TransformerLayer(d, h, conf.flash) for _ in range(n)])
  391. self.log_assignment = ModuleList([MatchAssignment(d) for _ in range(n)])
  392. self.token_confidence = ModuleList([TokenConfidence(d) for _ in range(n - 1)])
  393. self.register_buffer(
  394. "confidence_thresholds",
  395. Tensor([self.confidence_threshold(i) for i in range(self.conf.n_layers)]),
  396. )
  397. state_dict = None
  398. if features is not None:
  399. fname = f"{conf.weights}_{self.version}.pth".replace(".", "-")
  400. if features == "dog_affnet_hardnet":
  401. features = "doghardnet" # new dog model is better for affnet as well
  402. if features in ["keynet_affnet_hardnet"]:
  403. fname = "keynet_affnet_hardnet_lightlue.pth"
  404. url = "http://cmp.felk.cvut.cz/~mishkdmy/models/keynet_affnet_hardnet_lightlue.pth"
  405. elif features in ["dedodeb"]:
  406. fname = "dedodeb_lightglue.pth"
  407. url = "http://cmp.felk.cvut.cz/~mishkdmy/models/dedodeb_lightglue.pth"
  408. elif features in ["dedodeg"]:
  409. fname = "dedodeg_lightglue.pth"
  410. url = "http://cmp.felk.cvut.cz/~mishkdmy/models/dedodeg_lightglue.pth"
  411. else:
  412. url = self.url.format(self.version, features)
  413. state_dict = torch.hub.load_state_dict_from_url(url, file_name=fname)
  414. elif conf.weights is not None:
  415. path = Path(__file__).parent
  416. path = path / f"weights/{self.conf.weights}.pth"
  417. state_dict = torch.load(str(path), map_location="cpu")
  418. if state_dict:
  419. # rename old state dict entries
  420. for i in range(self.conf.n_layers):
  421. pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
  422. state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
  423. pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
  424. state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
  425. self.load_state_dict(state_dict, strict=False)
  426. print("Loaded LightGlue model")
  427. # static lengths LightGlue is compiled for (only used with torch.compile)
  428. self.static_lengths = None
  429. def compile(
  430. self, mode: str = "reduce-overhead", static_lengths: Sequence[int] = (256, 512, 768, 1024, 1280, 1536)
  431. ) -> None:
  432. if self.conf.width_confidence != -1:
  433. warnings.warn(
  434. "Point pruning is partially disabled for compiled forward.",
  435. stacklevel=2,
  436. )
  437. for i in range(self.conf.n_layers):
  438. self.transformers[i].masked_forward = torch.compile( # type: ignore[assignment]
  439. self.transformers[i].masked_forward, mode=mode, fullgraph=True
  440. )
  441. self.static_lengths = static_lengths # type: ignore
  442. def forward(self, data: dict) -> dict: # type: ignore
  443. """Match keypoints and descriptors between two images.
  444. Input (dict):
  445. image0: dict
  446. keypoints: [B x M x 2]
  447. descriptors: [B x M x D]
  448. image: [B x C x H x W] or image_size: [B x 2]
  449. image1: dict
  450. keypoints: [B x N x 2]
  451. descriptors: [B x N x D]
  452. image: [B x C x H x W] or image_size: [B x 2]
  453. Output (dict):
  454. log_assignment: [B x M+1 x N+1]
  455. matches0: [B x M]
  456. matching_scores0: [B x M]
  457. matches1: [B x N]
  458. matching_scores1: [B x N]
  459. matches: List[[Si x 2]], scores: List[[Si]]
  460. """
  461. with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
  462. return self._forward(data)
  463. def _forward(self, data: dict) -> dict: # type: ignore
  464. for key in self.required_data_keys:
  465. KORNIA_CHECK(key in data, f"Missing key {key} in data")
  466. data0, data1 = data["image0"], data["image1"]
  467. kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
  468. b, m, _ = kpts0.shape
  469. b, n, _ = kpts1.shape
  470. device = kpts0.device
  471. size0, size1 = data0.get("image_size"), data1.get("image_size")
  472. size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1]
  473. size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1]
  474. kpts0 = normalize_keypoints(kpts0, size0).clone()
  475. kpts1 = normalize_keypoints(kpts1, size1).clone()
  476. KORNIA_CHECK(torch.all(kpts0 >= -1).item() and torch.all(kpts0 <= 1).item(), "") # type: ignore
  477. KORNIA_CHECK(torch.all(kpts1 >= -1).item() and torch.all(kpts1 <= 1).item(), "") # type: ignore
  478. if self.conf.add_scale_ori:
  479. kpts0 = concatenate([kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1)
  480. if self.conf.scale_coef != 1.0:
  481. kpts0[..., -2] = kpts0[..., -2] * self.conf.scale_coef
  482. kpts1 = concatenate([kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1)
  483. if self.conf.scale_coef != 1.0:
  484. kpts1[..., -2] = kpts1[..., -2] * self.conf.scale_coef
  485. elif self.conf.add_laf:
  486. laf0 = scale_laf(data0["lafs"], self.conf.scale_coef)
  487. laf1 = scale_laf(data1["lafs"], self.conf.scale_coef)
  488. laf0 = laf_to_three_points(laf0)
  489. laf1 = laf_to_three_points(laf1)
  490. kpts0 = concatenate(
  491. [
  492. kpts0,
  493. normalize_keypoints(laf0[..., 0], size0).clone().to(kpts0.dtype),
  494. normalize_keypoints(laf0[..., 1], size0).clone().to(kpts0.dtype),
  495. ],
  496. -1,
  497. )
  498. kpts1 = concatenate(
  499. [
  500. kpts1,
  501. normalize_keypoints(laf1[..., 0], size1).clone().to(kpts1.dtype),
  502. normalize_keypoints(laf1[..., 1], size1).clone().to(kpts1.dtype),
  503. ],
  504. -1,
  505. )
  506. desc0 = data0["descriptors"].detach().contiguous()
  507. desc1 = data1["descriptors"].detach().contiguous()
  508. KORNIA_CHECK(desc0.shape[-1] == self.conf.input_dim, "Descriptor dimension does not match input dim in config")
  509. KORNIA_CHECK(desc1.shape[-1] == self.conf.input_dim, "Descriptor dimension does not match input dim in config")
  510. if torch.is_autocast_enabled():
  511. desc0 = desc0.half()
  512. desc1 = desc1.half()
  513. mask0, mask1 = None, None
  514. c = max(m, n)
  515. do_compile = self.static_lengths and c <= max(self.static_lengths)
  516. if do_compile:
  517. kn = min([k for k in self.static_lengths if k >= c])
  518. desc0, mask0 = pad_to_length(desc0, kn)
  519. desc1, mask1 = pad_to_length(desc1, kn)
  520. kpts0, _ = pad_to_length(kpts0, kn)
  521. kpts1, _ = pad_to_length(kpts1, kn)
  522. desc0 = self.input_proj(desc0)
  523. desc1 = self.input_proj(desc1)
  524. # cache positional embeddings
  525. encoding0 = self.posenc(kpts0)
  526. encoding1 = self.posenc(kpts1)
  527. # GNN + final_proj + assignment
  528. do_early_stop = self.conf.depth_confidence > 0
  529. do_point_pruning = self.conf.width_confidence > 0 and not do_compile
  530. pruning_th = self.pruning_min_kpts(device)
  531. if do_point_pruning:
  532. ind0 = arange(0, m, device=device)[None]
  533. ind1 = arange(0, n, device=device)[None]
  534. # We store the index of the layer at which pruning is detected.
  535. prune0 = ones_like(ind0)
  536. prune1 = ones_like(ind1)
  537. token0, token1 = None, None
  538. for i in range(self.conf.n_layers):
  539. desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1)
  540. if i == self.conf.n_layers - 1:
  541. continue # no early stopping or adaptive width at last layer
  542. if do_early_stop:
  543. token0, token1 = self.token_confidence[i](desc0, desc1)
  544. if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
  545. break
  546. if do_point_pruning and desc0.shape[-2] > pruning_th:
  547. scores0 = self.log_assignment[i].get_matchability(desc0)
  548. prunemask0 = self.get_pruning_mask(token0, scores0, i) # type: ignore
  549. keep0 = where(prunemask0)[1]
  550. ind0 = ind0.index_select(1, keep0)
  551. desc0 = desc0.index_select(1, keep0)
  552. encoding0 = encoding0.index_select(-2, keep0)
  553. prune0[:, ind0] += 1
  554. if do_point_pruning and desc1.shape[-2] > pruning_th:
  555. scores1 = self.log_assignment[i].get_matchability(desc1)
  556. prunemask1 = self.get_pruning_mask(token1, scores1, i) # type: ignore
  557. keep1 = where(prunemask1)[1]
  558. ind1 = ind1.index_select(1, keep1)
  559. desc1 = desc1.index_select(1, keep1)
  560. encoding1 = encoding1.index_select(-2, keep1)
  561. prune1[:, ind1] += 1
  562. desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
  563. scores, _ = self.log_assignment[i](desc0, desc1)
  564. m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
  565. matches, mscores = [], []
  566. for k in range(b):
  567. valid = m0[k] > -1
  568. m_indices_0 = where(valid)[0]
  569. m_indices_1 = m0[k][valid]
  570. if do_point_pruning:
  571. m_indices_0 = ind0[k, m_indices_0]
  572. m_indices_1 = ind1[k, m_indices_1]
  573. matches.append(stack([m_indices_0, m_indices_1], -1))
  574. mscores.append(mscores0[k][valid])
  575. # TODO: Remove when hloc switches to the compact format.
  576. if do_point_pruning:
  577. m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
  578. m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
  579. m0_[:, ind0] = where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
  580. m1_[:, ind1] = where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
  581. mscores0_ = zeros((b, m), device=mscores0.device)
  582. mscores1_ = zeros((b, n), device=mscores1.device)
  583. mscores0_[:, ind0] = mscores0
  584. mscores1_[:, ind1] = mscores1
  585. m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
  586. else:
  587. prune0 = ones_like(mscores0) * self.conf.n_layers
  588. prune1 = ones_like(mscores1) * self.conf.n_layers
  589. pred = {
  590. "log_assignment": scores,
  591. "matches0": m0,
  592. "matches1": m1,
  593. "matching_scores0": mscores0,
  594. "matching_scores1": mscores1,
  595. "stop": i + 1,
  596. "matches": matches,
  597. "scores": mscores,
  598. "prune0": prune0,
  599. "prune1": prune1,
  600. }
  601. return pred
  602. def confidence_threshold(self, layer_index: int) -> float:
  603. """Scaled confidence threshold."""
  604. threshold = 0.8 + 0.1 * math.exp(-4.0 * layer_index / self.conf.n_layers)
  605. return min(max(threshold, 0), 1)
  606. def get_pruning_mask(self, confidences: Tensor, scores: Tensor, layer_index: int) -> Tensor:
  607. """Mask points which should be removed."""
  608. keep = scores > (1 - self.conf.width_confidence)
  609. if confidences is not None: # Low-confidence points are never pruned.
  610. keep |= confidences <= self.confidence_thresholds[layer_index]
  611. return keep
  612. def check_if_stop(
  613. self,
  614. confidences0: Tensor,
  615. confidences1: Tensor,
  616. layer_index: int,
  617. num_points: int,
  618. ) -> Tensor:
  619. """Evaluate stopping condition."""
  620. confidences = concatenate([confidences0, confidences1], -1)
  621. threshold = self.confidence_thresholds[layer_index]
  622. ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
  623. return ratio_confident > self.conf.depth_confidence
  624. def pruning_min_kpts(self, device: torch.device) -> int:
  625. if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
  626. return self.pruning_keypoint_thresholds["flash"]
  627. else:
  628. return self.pruning_keypoint_thresholds[device.type]