# lightglue.py
  1. import warnings
  2. from pathlib import Path
  3. from types import SimpleNamespace
  4. from typing import Callable, List, Optional, Tuple
  5. import numpy as np
  6. import torch
  7. import torch.nn.functional as F
  8. from torch import nn
# Optional FlashAttention kernels from the flash-attn package; most installs
# do not have it, in which case we fall back to torch's own attention.
try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

# Flash-style attention is usable either through flash-attn or through
# torch >= 2.0's scaled_dot_product_attention.
if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
    FLASH_AVAILABLE = True
else:
    FLASH_AVAILABLE = False

# Favor reproducible cuDNN kernels over autotuned (possibly faster) ones.
torch.backends.cudnn.deterministic = True

# Decorator that forces a function to run in float32 under autocast.
# torch.amp.custom_fwd (with device_type) is the modern API; older torch
# versions only provide the deprecated torch.cuda.amp.custom_fwd.
AMP_CUSTOM_FWD_F32 = (
    torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
    if hasattr(torch, "amp") and hasattr(torch.amp, "custom_fwd")
    else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
)
  23. @AMP_CUSTOM_FWD_F32
  24. def normalize_keypoints(
  25. kpts: torch.Tensor, size: Optional[torch.Tensor] = None
  26. ) -> torch.Tensor:
  27. if size is None:
  28. size = 1 + kpts.max(-2).values - kpts.min(-2).values
  29. elif not isinstance(size, torch.Tensor):
  30. size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
  31. size = size.to(kpts)
  32. shift = size / 2
  33. scale = size.max(-1).values / 2
  34. kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
  35. return kpts
  36. def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
  37. if length <= x.shape[-2]:
  38. return x, torch.ones_like(x[..., :1], dtype=torch.bool)
  39. pad = torch.ones(
  40. *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
  41. )
  42. y = torch.cat([x, pad], dim=-2)
  43. mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
  44. mask[..., : x.shape[-2], :] = True
  45. return y, mask
  46. def rotate_half(x: torch.Tensor) -> torch.Tensor:
  47. x = x.unflatten(-1, (-1, 2))
  48. x1, x2 = x.unbind(dim=-1)
  49. return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
  50. def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
  51. return (t * freqs[0]) + (rotate_half(t) * freqs[1])
  52. class LearnableFourierPositionalEncoding(nn.Module):
  53. def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
  54. super().__init__()
  55. F_dim = F_dim if F_dim is not None else dim
  56. self.gamma = gamma
  57. self.Wr = nn.Linear(M, F_dim // 2, bias=False)
  58. nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
  59. def forward(self, x: torch.Tensor) -> torch.Tensor:
  60. """encode position vector"""
  61. projected = self.Wr(x)
  62. cosines, sines = torch.cos(projected), torch.sin(projected)
  63. emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
  64. return emb.repeat_interleave(2, dim=-1)
  65. class TokenConfidence(nn.Module):
  66. def __init__(self, dim: int) -> None:
  67. super().__init__()
  68. self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
  69. def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
  70. """get confidence tokens"""
  71. return (
  72. self.token(desc0.detach()).squeeze(-1),
  73. self.token(desc1.detach()).squeeze(-1),
  74. )
  75. class Attention(nn.Module):
  76. def __init__(self, allow_flash: bool) -> None:
  77. super().__init__()
  78. if allow_flash and not FLASH_AVAILABLE:
  79. warnings.warn(
  80. "FlashAttention is not available. For optimal speed, "
  81. "consider installing torch >= 2.0 or flash-attn.",
  82. stacklevel=2,
  83. )
  84. self.enable_flash = allow_flash and FLASH_AVAILABLE
  85. self.has_sdp = hasattr(F, "scaled_dot_product_attention")
  86. if allow_flash and FlashCrossAttention:
  87. self.flash_ = FlashCrossAttention()
  88. if self.has_sdp:
  89. torch.backends.cuda.enable_flash_sdp(allow_flash)
  90. def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  91. if q.shape[-2] == 0 or k.shape[-2] == 0:
  92. return q.new_zeros((*q.shape[:-1], v.shape[-1]))
  93. if self.enable_flash and q.device.type == "cuda":
  94. # use torch 2.0 scaled_dot_product_attention with flash
  95. if self.has_sdp:
  96. args = [x.half().contiguous() for x in [q, k, v]]
  97. v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
  98. return v if mask is None else v.nan_to_num()
  99. else:
  100. assert mask is None
  101. q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
  102. m = self.flash_(q.half(), torch.stack([k, v], 2).half())
  103. return m.transpose(-2, -3).to(q.dtype).clone()
  104. elif self.has_sdp:
  105. args = [x.contiguous() for x in [q, k, v]]
  106. v = F.scaled_dot_product_attention(*args, attn_mask=mask)
  107. return v if mask is None else v.nan_to_num()
  108. else:
  109. s = q.shape[-1] ** -0.5
  110. sim = torch.einsum("...id,...jd->...ij", q, k) * s
  111. if mask is not None:
  112. sim.masked_fill(~mask, -float("inf"))
  113. attn = F.softmax(sim, -1)
  114. return torch.einsum("...ij,...jd->...id", attn, v)
  115. class SelfBlock(nn.Module):
  116. def __init__(
  117. self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
  118. ) -> None:
  119. super().__init__()
  120. self.embed_dim = embed_dim
  121. self.num_heads = num_heads
  122. assert self.embed_dim % num_heads == 0
  123. self.head_dim = self.embed_dim // num_heads
  124. self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
  125. self.inner_attn = Attention(flash)
  126. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  127. self.ffn = nn.Sequential(
  128. nn.Linear(2 * embed_dim, 2 * embed_dim),
  129. nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
  130. nn.GELU(),
  131. nn.Linear(2 * embed_dim, embed_dim),
  132. )
  133. def forward(
  134. self,
  135. x: torch.Tensor,
  136. encoding: torch.Tensor,
  137. mask: Optional[torch.Tensor] = None,
  138. ) -> torch.Tensor:
  139. qkv = self.Wqkv(x)
  140. qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
  141. q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
  142. q = apply_cached_rotary_emb(encoding, q)
  143. k = apply_cached_rotary_emb(encoding, k)
  144. context = self.inner_attn(q, k, v, mask=mask)
  145. message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
  146. return x + self.ffn(torch.cat([x, message], -1))
  147. class CrossBlock(nn.Module):
  148. def __init__(
  149. self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
  150. ) -> None:
  151. super().__init__()
  152. self.heads = num_heads
  153. dim_head = embed_dim // num_heads
  154. self.scale = dim_head**-0.5
  155. inner_dim = dim_head * num_heads
  156. self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
  157. self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
  158. self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
  159. self.ffn = nn.Sequential(
  160. nn.Linear(2 * embed_dim, 2 * embed_dim),
  161. nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
  162. nn.GELU(),
  163. nn.Linear(2 * embed_dim, embed_dim),
  164. )
  165. if flash and FLASH_AVAILABLE:
  166. self.flash = Attention(True)
  167. else:
  168. self.flash = None
  169. def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
  170. return func(x0), func(x1)
  171. def forward(
  172. self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
  173. ) -> List[torch.Tensor]:
  174. qk0, qk1 = self.map_(self.to_qk, x0, x1)
  175. v0, v1 = self.map_(self.to_v, x0, x1)
  176. qk0, qk1, v0, v1 = map(
  177. lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
  178. (qk0, qk1, v0, v1),
  179. )
  180. if self.flash is not None and qk0.device.type == "cuda":
  181. m0 = self.flash(qk0, qk1, v1, mask)
  182. m1 = self.flash(
  183. qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
  184. )
  185. else:
  186. qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
  187. sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
  188. if mask is not None:
  189. sim = sim.masked_fill(~mask, -float("inf"))
  190. attn01 = F.softmax(sim, dim=-1)
  191. attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
  192. m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
  193. m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
  194. if mask is not None:
  195. m0, m1 = m0.nan_to_num(), m1.nan_to_num()
  196. m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
  197. m0, m1 = self.map_(self.to_out, m0, m1)
  198. x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
  199. x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
  200. return x0, x1
  201. class TransformerLayer(nn.Module):
  202. def __init__(self, *args, **kwargs):
  203. super().__init__()
  204. self.self_attn = SelfBlock(*args, **kwargs)
  205. self.cross_attn = CrossBlock(*args, **kwargs)
  206. def forward(
  207. self,
  208. desc0,
  209. desc1,
  210. encoding0,
  211. encoding1,
  212. mask0: Optional[torch.Tensor] = None,
  213. mask1: Optional[torch.Tensor] = None,
  214. ):
  215. if mask0 is not None and mask1 is not None:
  216. return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
  217. else:
  218. desc0 = self.self_attn(desc0, encoding0)
  219. desc1 = self.self_attn(desc1, encoding1)
  220. return self.cross_attn(desc0, desc1)
  221. # This part is compiled and allows padding inputs
  222. def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
  223. mask = mask0 & mask1.transpose(-1, -2)
  224. mask0 = mask0 & mask0.transpose(-1, -2)
  225. mask1 = mask1 & mask1.transpose(-1, -2)
  226. desc0 = self.self_attn(desc0, encoding0, mask0)
  227. desc1 = self.self_attn(desc1, encoding1, mask1)
  228. return self.cross_attn(desc0, desc1, mask)
  229. def sigmoid_log_double_softmax(
  230. sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
  231. ) -> torch.Tensor:
  232. """create the log assignment matrix from logits and similarity"""
  233. b, m, n = sim.shape
  234. certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
  235. scores0 = F.log_softmax(sim, 2)
  236. scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
  237. scores = sim.new_full((b, m + 1, n + 1), 0)
  238. scores[:, :m, :n] = scores0 + scores1 + certainties
  239. scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
  240. scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
  241. return scores
  242. class MatchAssignment(nn.Module):
  243. def __init__(self, dim: int) -> None:
  244. super().__init__()
  245. self.dim = dim
  246. self.matchability = nn.Linear(dim, 1, bias=True)
  247. self.final_proj = nn.Linear(dim, dim, bias=True)
  248. def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
  249. """build assignment matrix from descriptors"""
  250. mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
  251. _, _, d = mdesc0.shape
  252. mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
  253. sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
  254. z0 = self.matchability(desc0)
  255. z1 = self.matchability(desc1)
  256. scores = sigmoid_log_double_softmax(sim, z0, z1)
  257. return scores, sim
  258. def get_matchability(self, desc: torch.Tensor):
  259. return torch.sigmoid(self.matchability(desc)).squeeze(-1)
  260. def filter_matches(scores: torch.Tensor, th: float):
  261. """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
  262. max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
  263. m0, m1 = max0.indices, max1.indices
  264. indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
  265. indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
  266. mutual0 = indices0 == m1.gather(1, m0)
  267. mutual1 = indices1 == m0.gather(1, m1)
  268. max0_exp = max0.values.exp()
  269. zero = max0_exp.new_tensor(0)
  270. mscores0 = torch.where(mutual0, max0_exp, zero)
  271. mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
  272. valid0 = mutual0 & (mscores0 > th)
  273. valid1 = mutual1 & valid0.gather(1, m1)
  274. m0 = torch.where(valid0, m0, -1)
  275. m1 = torch.where(valid1, m1, -1)
  276. return m0, m1, mscores0, mscores1
class LightGlue(nn.Module):
    """LightGlue feature matcher.

    Matches two sets of local features with a stack of self/cross attention
    layers. Two adaptive mechanisms trade accuracy for speed: early stopping
    over depth (``depth_confidence``) and point pruning over width
    (``width_confidence``).
    """

    # Default hyperparameters; per-feature presets and **conf override these.
    default_conf = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "add_scale_ori": False,
        "n_layers": 9,
        "num_heads": 4,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": 0.95,  # early stopping, disable with -1
        "width_confidence": 0.99,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }

    # Point pruning involves an overhead (gather).
    # Therefore, we only activate it if there are enough keypoints.
    pruning_keypoint_thresholds = {
        "cpu": -1,
        "mps": -1,
        "cuda": 1024,
        "flash": 1536,
    }

    required_data_keys = ["image0", "image1"]

    # Release tag used to build the pretrained-weight download URL.
    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}.pth"

    # Presets for supported local-feature extractors: which checkpoint to
    # load and the matching descriptor dimensionality.
    features = {
        "superpoint": {
            "weights": "superpoint_lightglue",
            "input_dim": 256,
        },
        "disk": {
            "weights": "disk_lightglue",
            "input_dim": 128,
        },
        "aliked": {
            "weights": "aliked_lightglue",
            "input_dim": 128,
        },
        "raco-aliked": {
            "weights": "raco_aliked_lightglue",
            "input_dim": 128,
        },
        "sift": {
            "weights": "sift_lightglue",
            "input_dim": 128,
            "add_scale_ori": True,
        },
        "doghardnet": {
            "weights": "doghardnet_lightglue",
            "input_dim": 128,
            "add_scale_ori": True,
        },
    }

    def __init__(self, features="superpoint", **conf) -> None:
        """Build the network and (when `features` is set) load its weights.

        Args:
            features: name of a preset in `self.features`, or None to
                configure manually via **conf (optionally with local weights).
            **conf: overrides for `default_conf`.
        """
        super().__init__()
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        if features is not None:
            if features not in self.features:
                raise ValueError(
                    f"Unsupported features: {features} not in "
                    f"{{{','.join(self.features)}}}"
                )
            # the preset overrides any user-provided value for these keys
            for k, v in self.features[features].items():
                setattr(conf, k, v)

        # project input descriptors to the internal width if they differ
        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        head_dim = conf.descriptor_dim // conf.num_heads
        # positions are (x, y) plus optionally (scale, orientation)
        self.posenc = LearnableFourierPositionalEncoding(
            2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
        )

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

        self.transformers = nn.ModuleList(
            [TransformerLayer(d, h, conf.flash) for _ in range(n)]
        )

        # one assignment head per layer so any layer can be the exit point;
        # no confidence head is needed after the last layer
        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
        self.token_confidence = nn.ModuleList(
            [TokenConfidence(d) for _ in range(n - 1)]
        )
        self.register_buffer(
            "confidence_thresholds",
            torch.Tensor(
                [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
            ),
        )

        state_dict = None
        if features is not None:
            # download the released checkpoint for this feature preset
            fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
            state_dict = torch.hub.load_state_dict_from_url(
                self.url.format(self.version, self.conf.weights),
                file_name=fname,
            )
            self.load_state_dict(state_dict, strict=False)
        elif conf.weights is not None:
            # load a local checkpoint from the package's weights/ directory
            path = Path(__file__).parent
            path = path / "weights/{}.pth".format(self.conf.weights)
            state_dict = torch.load(str(path), map_location="cpu")

        if state_dict:
            # rename old state dict entries
            for i in range(self.conf.n_layers):
                pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
                pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            self.load_state_dict(state_dict, strict=False)

        # static lengths LightGlue is compiled for (only used with torch.compile)
        self.static_lengths = None

    def compile(
        self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
    ):
        """Compile the masked forward of every layer with torch.compile.

        Inputs are padded to the nearest value in `static_lengths` so the
        compiled graphs are reused. NOTE: shadows nn.Module.compile, and the
        mutable default list is shared across calls (read-only here).
        """
        if self.conf.width_confidence != -1:
            warnings.warn(
                "Point pruning is partially disabled for compiled forward.",
                stacklevel=2,
            )

        torch._inductor.cudagraph_mark_step_begin()
        for i in range(self.conf.n_layers):
            self.transformers[i].masked_forward = torch.compile(
                self.transformers[i].masked_forward, mode=mode, fullgraph=True
            )

        self.static_lengths = static_lengths

    def forward(self, data: dict) -> dict:
        """
        Match keypoints and descriptors between two images

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                image: [B x C x H x W] or image_size: [B x 2]
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                image: [B x C x H x W] or image_size: [B x 2]
        Output (dict):
            matches0: [B x M]
            matching_scores0: [B x M]
            matches1: [B x N]
            matching_scores1: [B x N]
            matches: List[[Si x 2]]
            scores: List[[Si]]
            stop: int
            prune0: [B x M]
            prune1: [B x N]
        """
        # mixed precision is only enabled when conf.mp is True
        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
            return self._forward(data)

    def _forward(self, data: dict) -> dict:
        """Core matching pipeline; see `forward` for the data contract."""
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"

        data0, data1 = data["image0"], data["image1"]
        kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
        b, m, _ = kpts0.shape
        b, n, _ = kpts1.shape
        device = kpts0.device
        size0, size1 = data0.get("image_size"), data1.get("image_size")
        kpts0 = normalize_keypoints(kpts0, size0).clone()
        kpts1 = normalize_keypoints(kpts1, size1).clone()

        if self.conf.add_scale_ori:
            # append per-keypoint scale and orientation (e.g. for SIFT)
            kpts0 = torch.cat(
                [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )
            kpts1 = torch.cat(
                [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )

        desc0 = data0["descriptors"].detach().contiguous()
        desc1 = data1["descriptors"].detach().contiguous()

        assert desc0.shape[-1] == self.conf.input_dim
        assert desc1.shape[-1] == self.conf.input_dim

        if torch.is_autocast_enabled():
            desc0 = desc0.half()
            desc1 = desc1.half()

        mask0, mask1 = None, None
        c = max(m, n)
        # use the compiled masked path only if inputs fit a static length
        do_compile = self.static_lengths and c <= max(self.static_lengths)
        if do_compile:
            kn = min([k for k in self.static_lengths if k >= c])
            desc0, mask0 = pad_to_length(desc0, kn)
            desc1, mask1 = pad_to_length(desc1, kn)
            kpts0, _ = pad_to_length(kpts0, kn)
            kpts1, _ = pad_to_length(kpts1, kn)
        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)
        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        # GNN + final_proj + assignment
        do_early_stop = self.conf.depth_confidence > 0
        do_point_pruning = self.conf.width_confidence > 0 and not do_compile
        pruning_th = self.pruning_min_kpts(device)
        if do_point_pruning:
            # ind0/ind1 map surviving (pruned) positions back to originals
            ind0 = torch.arange(0, m, device=device)[None]
            ind1 = torch.arange(0, n, device=device)[None]
            # We store the index of the layer at which pruning is detected.
            prune0 = torch.ones_like(ind0)
            prune1 = torch.ones_like(ind1)
        token0, token1 = None, None
        for i in range(self.conf.n_layers):
            if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
                break
            desc0, desc1 = self.transformers[i](
                desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
            )
            if i == self.conf.n_layers - 1:
                continue  # no early stopping or adaptive width at last layer

            if do_early_stop:
                # confidence is evaluated on unpadded tokens only
                token0, token1 = self.token_confidence[i](desc0, desc1)
                if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
                    break
            if do_point_pruning and desc0.shape[-2] > pruning_th:
                scores0 = self.log_assignment[i].get_matchability(desc0)
                prunemask0 = self.get_pruning_mask(token0, scores0, i)
                keep0 = torch.where(prunemask0)[1]
                ind0 = ind0.index_select(1, keep0)
                desc0 = desc0.index_select(1, keep0)
                encoding0 = encoding0.index_select(-2, keep0)
                prune0[:, ind0] += 1
            if do_point_pruning and desc1.shape[-2] > pruning_th:
                scores1 = self.log_assignment[i].get_matchability(desc1)
                prunemask1 = self.get_pruning_mask(token1, scores1, i)
                keep1 = torch.where(prunemask1)[1]
                ind1 = ind1.index_select(1, keep1)
                desc1 = desc1.index_select(1, keep1)
                encoding1 = encoding1.index_select(-2, keep1)
                prune1[:, ind1] += 1

        if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
            # degenerate output: everything unmatched
            m0 = desc0.new_full((b, m), -1, dtype=torch.long)
            m1 = desc1.new_full((b, n), -1, dtype=torch.long)
            mscores0 = desc0.new_zeros((b, m))
            mscores1 = desc1.new_zeros((b, n))
            matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
            mscores = desc0.new_empty((b, 0))
            if not do_point_pruning:
                prune0 = torch.ones_like(mscores0) * self.conf.n_layers
                prune1 = torch.ones_like(mscores1) * self.conf.n_layers
            return {
                "matches0": m0,
                "matches1": m1,
                "matching_scores0": mscores0,
                "matching_scores1": mscores1,
                "stop": i + 1,
                "matches": matches,
                "scores": mscores,
                "prune0": prune0,
                "prune1": prune1,
            }

        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]  # remove padding
        # assignment head of the layer we stopped at (i from the loop above)
        scores, _ = self.log_assignment[i](desc0, desc1)
        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
        matches, mscores = [], []
        for k in range(b):
            valid = m0[k] > -1
            m_indices_0 = torch.where(valid)[0]
            m_indices_1 = m0[k][valid]
            if do_point_pruning:
                # translate pruned indices back to original keypoint indices
                m_indices_0 = ind0[k, m_indices_0]
                m_indices_1 = ind1[k, m_indices_1]
            matches.append(torch.stack([m_indices_0, m_indices_1], -1))
            mscores.append(mscores0[k][valid])

        # TODO: Remove when hloc switches to the compact format.
        if do_point_pruning:
            # scatter results for the surviving points back to full size
            m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
            m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
            m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
            m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
            mscores0_ = torch.zeros((b, m), device=mscores0.device)
            mscores1_ = torch.zeros((b, n), device=mscores1.device)
            mscores0_[:, ind0] = mscores0
            mscores1_[:, ind1] = mscores1
            m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
        else:
            prune0 = torch.ones_like(mscores0) * self.conf.n_layers
            prune1 = torch.ones_like(mscores1) * self.conf.n_layers

        return {
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "stop": i + 1,
            "matches": matches,
            "scores": mscores,
            "prune0": prune0,
            "prune1": prune1,
        }

    def confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold"""
        # decays from ~0.9 toward 0.8 with depth: later layers need less
        # confidence to justify continuing
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
        return np.clip(threshold, 0, 1)

    def get_pruning_mask(
        self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
    ) -> torch.Tensor:
        """mask points which should be removed"""
        keep = scores > (1 - self.conf.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: torch.Tensor,
        confidences1: torch.Tensor,
        layer_index: int,
        num_points: int,
    ) -> torch.Tensor:
        """evaluate stopping condition"""
        confidences = torch.cat([confidences0, confidences1], -1)
        threshold = self.confidence_thresholds[layer_index]
        # stop when the fraction of confident points is high enough
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.conf.depth_confidence

    def pruning_min_kpts(self, device: torch.device):
        """Minimum keypoint count for pruning to pay off on this device."""
        if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
            return self.pruning_keypoint_thresholds["flash"]
        else:
            return self.pruning_keypoint_thresholds[device.type]