mambaglue.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859
  1. import warnings
  2. from pathlib import Path
  3. from types import SimpleNamespace
  4. from typing import Callable, List, Optional, Tuple
  5. from omegaconf import OmegaConf
  6. import numpy as np
  7. import torch
  8. import torch.nn.functional as F
  9. from torch import nn
  10. from einops import rearrange, repeat
  11. from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
  12. import math
  13. import requests
  14. from torch.hub import download_url_to_file
  15. try:
  16. from flash_attn.modules.mha import FlashCrossAttention
  17. except ModuleNotFoundError:
  18. FlashCrossAttention = None
  19. if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
  20. FLASH_AVAILABLE = True
  21. else:
  22. FLASH_AVAILABLE = False
  23. torch.backends.cudnn.deterministic = True
  24. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
  25. def normalize_keypoints(
  26. kpts: torch.Tensor, size: Optional[torch.Tensor] = None
  27. ) -> torch.Tensor:
  28. if size is None:
  29. size = 1 + kpts.max(-2).values - kpts.min(-2).values
  30. elif not isinstance(size, torch.Tensor):
  31. size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
  32. size = size.to(kpts)
  33. shift = size / 2
  34. scale = size.max(-1).values / 2
  35. kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
  36. return kpts
  37. def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
  38. if length <= x.shape[-2]:
  39. return x, torch.ones_like(x[..., :1], dtype=torch.bool)
  40. pad = torch.ones(
  41. *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
  42. )
  43. y = torch.cat([x, pad], dim=-2)
  44. mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
  45. mask[..., : x.shape[-2], :] = True
  46. return y, mask
  47. def rotate_half(x: torch.Tensor) -> torch.Tensor:
  48. x = x.unflatten(-1, (-1, 2))
  49. x1, x2 = x.unbind(dim=-1)
  50. return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
  51. def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
  52. return (t * freqs[0]) + (rotate_half(t) * freqs[1])
  53. class LearnableFourierPositionalEncoding(nn.Module):
  54. def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
  55. super().__init__()
  56. F_dim = F_dim if F_dim is not None else dim
  57. self.gamma = gamma
  58. self.Wr = nn.Linear(M, F_dim // 2, bias=False)
  59. nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)
  60. def forward(self, x: torch.Tensor) -> torch.Tensor:
  61. """encode position vector"""
  62. projected = self.Wr(x)
  63. cosines, sines = torch.cos(projected), torch.sin(projected)
  64. emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
  65. return emb.repeat_interleave(2, dim=-1)
  66. class TokenConfidence(nn.Module):
  67. def __init__(self, dim: int) -> None:
  68. super().__init__()
  69. self.token = nn.Sequential(
  70. nn.Linear(dim, dim // 2),
  71. nn.ReLU(),
  72. nn.Linear(dim // 2, dim // 4),
  73. nn.ReLU(),
  74. nn.Linear(dim // 4, 1),
  75. nn.Sigmoid(),
  76. )
  77. def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
  78. """get confidence tokens"""
  79. return (
  80. self.token(desc0.detach()).squeeze(-1),
  81. self.token(desc1.detach()).squeeze(-1),
  82. )
  83. class Attention(nn.Module):
  84. def __init__(self, allow_flash: bool) -> None:
  85. super().__init__()
  86. if allow_flash and not FLASH_AVAILABLE:
  87. warnings.warn(
  88. "FlashAttention is not available. For optimal speed, "
  89. "consider installing torch >= 2.0 or flash-attn.",
  90. stacklevel=2,
  91. )
  92. self.enable_flash = allow_flash and FLASH_AVAILABLE
  93. self.has_sdp = hasattr(F, "scaled_dot_product_attention")
  94. if allow_flash and FlashCrossAttention:
  95. self.flash_ = FlashCrossAttention()
  96. if self.has_sdp:
  97. torch.backends.cuda.enable_flash_sdp(allow_flash)
  98. def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  99. if q.shape[-2] == 0 or k.shape[-2] == 0:
  100. return q.new_zeros((*q.shape[:-1], v.shape[-1]))
  101. if self.enable_flash and q.device.type == "cuda":
  102. # use torch 2.0 scaled_dot_product_attention with flash
  103. if self.has_sdp:
  104. args = [x.half().contiguous() for x in [q, k, v]]
  105. v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
  106. return v if mask is None else v.nan_to_num()
  107. else:
  108. assert mask is None
  109. q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
  110. m = self.flash_(q.half(), torch.stack([k, v], 2).half())
  111. return m.transpose(-2, -3).to(q.dtype).clone()
  112. elif self.has_sdp:
  113. args = [x.contiguous() for x in [q, k, v]]
  114. v = F.scaled_dot_product_attention(*args, attn_mask=mask)
  115. return v if mask is None else v.nan_to_num()
  116. else:
  117. s = q.shape[-1] ** -0.5
  118. sim = torch.einsum("...id,...jd->...ij", q, k) * s
  119. if mask is not None:
  120. sim.masked_fill(~mask, -float("inf"))
  121. attn = F.softmax(sim, -1)
  122. return torch.einsum("...ij,...jd->...id", attn, v)
class MambaMixer(nn.Module):
    """Mamba-style token mixer that returns a scanned branch and a plain branch.

    ``in_proj`` expands the input to ``d_inner`` channels, which are split into
    two halves ``x`` and ``z`` of ``d_inner // 2`` channels each:

    * ``x``: depthwise conv (``conv1d_x`` weights) -> SiLU -> ``selective_scan_fn``
      from ``mamba_ssm``, run along the sequence dimension L
      (NOTE(review): a recurrent scan, so token order presumably matters).
    * ``z``: depthwise conv (``conv1d_z`` weights) -> SiLU, with no scan.

    Both branches are mapped back to ``d_model`` by the *same* ``out_proj`` and
    returned separately; there is no multiplicative gating between them here.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=5,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        """
        Args:
            d_model: input/output feature width.
            d_state: SSM state size (columns of ``A``; width of the B and C projections).
            d_conv: kernel size of the two depthwise convolutions.
            expand: ``d_inner = int(expand * d_model)``; each branch gets ``d_inner // 2``.
            dt_rank: rank of the dt projection; "auto" means ceil(d_model / 16).
            dt_min, dt_max: range of the log-uniform dt used to initialise ``dt_proj.bias``.
            dt_init: "constant" or "random" initialisation of ``dt_proj.weight``.
            dt_scale: multiplies the ``dt_proj.weight`` init scale.
            dt_init_floor: lower clamp applied to the sampled dt.
            conv_bias: see the NOTE on the conv layers below.
            bias: whether ``in_proj`` and ``out_proj`` have a bias.
            use_fast_path, layer_idx: stored on the module but not read by ``forward``.
            device, dtype: factory kwargs for the linear/conv layers.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.in_proj = nn.Linear(
            self.d_model, self.d_inner, bias=bias, **factory_kwargs
        )
        # Per-token projection of the x branch to [dt (dt_rank) | B (d_state) | C (d_state)].
        self.x_proj = nn.Linear(
            self.d_inner // 2,
            self.dt_rank + self.d_state * 2,
            bias=False,
            **factory_kwargs,
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner // 2, bias=True, **factory_kwargs
        )
        # dt_proj.weight init scale: dt_rank**-0.5, multiplied by dt_scale.
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # Sample dt log-uniformly in [dt_min, dt_max], floored at dt_init_floor ...
        dt = torch.exp(
            torch.rand(self.d_inner // 2, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # ... and store its softplus inverse as the bias, so softplus(bias) == dt
        # (forward runs the scan with delta_softplus=True).
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Marker attribute; not read in this file (presumably used by an external init routine).
        self.dt_proj.bias._no_reinit = True
        # A = -exp(A_log) starts as -[1, 2, ..., d_state] for every channel;
        # A_log has shape [d_inner // 2, d_state].
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner // 2,
        ).contiguous()
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)
        # Marker attribute; not read in this file (presumably used by an optimizer setup).
        self.A_log._no_weight_decay = True
        # Per-channel D term passed to selective_scan_fn
        # (in Mamba this is the skip weight on the scan input — confirm against mamba_ssm).
        self.D = nn.Parameter(torch.ones(self.d_inner // 2, device=device))
        self.D._no_weight_decay = True
        # Shared by both branches in forward.
        self.out_proj = nn.Linear(
            self.d_inner // 2, self.d_model, bias=bias, **factory_kwargs
        )
        # NOTE(review): `conv_bias // 2` evaluates to 0 (falsy) for conv_bias=True,
        # so both depthwise convs are built WITHOUT a bias (conv1d_*.bias is None).
        # Left as-is: changing it would alter the module's parameter set and
        # therefore checkpoint compatibility.
        # Only these modules' weights/biases are used; forward calls F.conv1d directly.
        self.conv1d_x = nn.Conv1d(
            in_channels=self.d_inner // 2,
            out_channels=self.d_inner // 2,
            bias=conv_bias // 2,
            kernel_size=d_conv,
            groups=self.d_inner // 2,
            **factory_kwargs,
        )
        self.conv1d_z = nn.Conv1d(
            in_channels=self.d_inner // 2,
            out_channels=self.d_inner // 2,
            bias=conv_bias // 2,
            kernel_size=d_conv,
            groups=self.d_inner // 2,
            **factory_kwargs,
        )

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D) with D == d_model
        Returns: tuple (out_y, out_z), each (B, L, d_model) — the projected
        scan branch and the projected z branch.
        """
        (
            _,
            seqlen,
            _,
        ) = (
            hidden_states.shape
        )  # B, L (sequence length, e.g. number of keypoints), D
        xz = self.in_proj(hidden_states)  # [B, L, d_inner]
        xz = rearrange(xz, "b l d -> b d l")  # [B, d_inner, L]
        x, z = xz.chunk(2, dim=1)  # each [B, d_inner // 2, L]
        A = -torch.exp(self.A_log.float())  # [d_inner // 2, d_state]
        # padding="same": the conv is not causal; with the default odd
        # d_conv=5 each position also sees two later neighbours.
        x = F.silu(
            F.conv1d(
                input=x,
                weight=self.conv1d_x.weight,
                bias=self.conv1d_x.bias,
                padding="same",
                groups=self.d_inner // 2,
            )
        )  # [B, d_inner // 2, L]
        z = F.silu(
            F.conv1d(
                input=z,
                weight=self.conv1d_z.weight,
                bias=self.conv1d_z.bias,
                padding="same",
                groups=self.d_inner // 2,
            )
        )  # [B, d_inner // 2, L]
        x_dbl = self.x_proj(
            rearrange(x, "b d l -> (b l) d")
        )  # [B*L, dt_rank + 2*d_state]
        # NOTE: the local name B below shadows the batch-size letter used in comments.
        dt, B, C = torch.split(
            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
        )  # [B*L, dt_rank], [B*L, d_state], [B*L, d_state]
        # NOTE(review): self.dt_proj(dt) already adds dt_proj.bias, and the same
        # bias is passed again as delta_bias below, so it is applied twice.
        # Kept as-is: changing it would change outputs for existing weights.
        dt = rearrange(
            self.dt_proj(dt), "(b l) d -> b d l", l=seqlen
        )  # [B, d_inner // 2, L]
        B = rearrange(
            B, "(b l) dstate -> b dstate l", l=seqlen
        ).contiguous()  # [B, d_state, L]
        C = rearrange(
            C, "(b l) dstate -> b dstate l", l=seqlen
        ).contiguous()  # [B, d_state, L]
        y = selective_scan_fn(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=None,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=None,
        )  # expected [B, d_inner // 2, L]; z=None means no gating inside the scan
        y = rearrange(y, "b d l -> b l d")  # [B, L, d_inner // 2]
        z = rearrange(z, "b d l -> b l d")  # [B, L, d_inner // 2]
        out_y = self.out_proj(y)
        out_z = self.out_proj(z)
        return out_y, out_z  # each [B, L, d_model]
  274. class MambaAttentionMixer(nn.Module):
  275. def __init__(
  276. self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
  277. ) -> None:
  278. super().__init__()
  279. self.embed_dim = embed_dim
  280. self.num_heads = num_heads
  281. assert self.embed_dim % num_heads == 0
  282. self.head_dim = self.embed_dim // num_heads
  283. self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
  284. self.inner_attn = Attention(flash)
  285. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  286. self.ffn = nn.Sequential(
  287. nn.Linear(4 * embed_dim, 2 * embed_dim),
  288. nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
  289. nn.GELU(),
  290. nn.Linear(2 * embed_dim, embed_dim),
  291. )
  292. # Mamba
  293. self.mamba_mixer = MambaMixer(self.embed_dim)
  294. def forward(
  295. self,
  296. x: torch.Tensor,
  297. encoding: torch.Tensor,
  298. mask: Optional[torch.Tensor] = None,
  299. ) -> torch.Tensor:
  300. qkv = self.Wqkv(x)
  301. qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
  302. q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
  303. q = apply_cached_rotary_emb(encoding, q)
  304. k = apply_cached_rotary_emb(encoding, k)
  305. context = self.inner_attn(q, k, v, mask=mask)
  306. message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
  307. # Mamba
  308. mamba_y, mamba_z = self.mamba_mixer(x)
  309. return x + self.ffn(torch.cat([x, message, mamba_y, mamba_z], -1))
  310. class CrossBlock(nn.Module):
  311. def __init__(
  312. self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
  313. ) -> None:
  314. super().__init__()
  315. self.heads = num_heads
  316. dim_head = embed_dim // num_heads
  317. self.scale = dim_head**-0.5
  318. inner_dim = dim_head * num_heads
  319. self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
  320. self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
  321. self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
  322. self.ffn = nn.Sequential(
  323. nn.Linear(2 * embed_dim, 2 * embed_dim),
  324. nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
  325. nn.GELU(),
  326. nn.Linear(2 * embed_dim, embed_dim),
  327. )
  328. if flash and FLASH_AVAILABLE:
  329. self.flash = Attention(True)
  330. else:
  331. self.flash = None
  332. def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
  333. return func(x0), func(x1)
  334. def forward(
  335. self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
  336. ) -> List[torch.Tensor]:
  337. qk0, qk1 = self.map_(self.to_qk, x0, x1)
  338. v0, v1 = self.map_(self.to_v, x0, x1)
  339. qk0, qk1, v0, v1 = map(
  340. lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
  341. (qk0, qk1, v0, v1),
  342. )
  343. if self.flash is not None and qk0.device.type == "cuda":
  344. m0 = self.flash(qk0, qk1, v1, mask)
  345. m1 = self.flash(
  346. qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
  347. )
  348. else:
  349. qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
  350. sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
  351. if mask is not None:
  352. sim = sim.masked_fill(~mask, -float("inf"))
  353. attn01 = F.softmax(sim, dim=-1)
  354. attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
  355. m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
  356. m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
  357. if mask is not None:
  358. m0, m1 = m0.nan_to_num(), m1.nan_to_num()
  359. m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
  360. m0, m1 = self.map_(self.to_out, m0, m1)
  361. x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
  362. x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
  363. return x0, x1
  364. class TransformerMambaLayer(nn.Module):
  365. def __init__(self, *args, **kwargs):
  366. super().__init__()
  367. self.mamba_selfattn_mixer = MambaAttentionMixer(*args, **kwargs)
  368. self.cross_attn = CrossBlock(*args, **kwargs)
  369. def forward(
  370. self,
  371. desc0,
  372. desc1,
  373. encoding0,
  374. encoding1,
  375. mask0: Optional[torch.Tensor] = None,
  376. mask1: Optional[torch.Tensor] = None,
  377. ):
  378. if mask0 is not None and mask1 is not None:
  379. return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
  380. else:
  381. desc0 = self.mamba_selfattn_mixer(desc0, encoding0)
  382. desc1 = self.mamba_selfattn_mixer(desc1, encoding1)
  383. return self.cross_attn(desc0, desc1)
  384. # This part is compiled and allows padding inputs
  385. def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
  386. mask = mask0 & mask1.transpose(-1, -2)
  387. mask0 = mask0 & mask0.transpose(-1, -2)
  388. mask1 = mask1 & mask1.transpose(-1, -2)
  389. desc0 = self.mamba_selfattn_mixer(desc0, encoding0, mask0)
  390. desc1 = self.mamba_selfattn_mixer(desc1, encoding1, mask1)
  391. return self.cross_attn(desc0, desc1, mask)
  392. def sigmoid_log_double_softmax(
  393. sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
  394. ) -> torch.Tensor:
  395. """create the log assignment matrix from logits and similarity"""
  396. b, m, n = sim.shape
  397. certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
  398. scores0 = F.log_softmax(sim, 2)
  399. scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
  400. scores = sim.new_full((b, m + 1, n + 1), 0)
  401. scores[:, :m, :n] = scores0 + scores1 + certainties
  402. scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
  403. scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
  404. return scores
  405. class MatchAssignment(nn.Module):
  406. def __init__(self, dim: int) -> None:
  407. super().__init__()
  408. self.dim = dim
  409. self.matchability = nn.Linear(dim, 1, bias=True)
  410. self.final_proj = nn.Linear(dim, dim, bias=True)
  411. def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
  412. """build assignment matrix from descriptors"""
  413. mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
  414. _, _, d = mdesc0.shape
  415. mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
  416. sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
  417. z0 = self.matchability(desc0)
  418. z1 = self.matchability(desc1)
  419. scores = sigmoid_log_double_softmax(sim, z0, z1)
  420. return scores, sim
  421. def get_matchability(self, desc: torch.Tensor):
  422. return torch.sigmoid(self.matchability(desc)).squeeze(-1)
  423. def filter_matches(scores: torch.Tensor, th: float):
  424. """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
  425. max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
  426. m0, m1 = max0.indices, max1.indices
  427. indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
  428. indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
  429. mutual0 = indices0 == m1.gather(1, m0)
  430. mutual1 = indices1 == m0.gather(1, m1)
  431. max0_exp = max0.values.exp()
  432. zero = max0_exp.new_tensor(0)
  433. mscores0 = torch.where(mutual0, max0_exp, zero)
  434. mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
  435. valid0 = mutual0 & (mscores0 > th)
  436. valid1 = mutual1 & valid0.gather(1, m1)
  437. m0 = torch.where(valid0, m0, -1)
  438. m1 = torch.where(valid1, m1, -1)
  439. return m0, m1, mscores0, mscores1
  440. class MambaGlue(nn.Module):
  441. default_conf = {
  442. "name": "mambaglue", # just for interfacing
  443. "input_dim": 256, # input descriptor dimension (autoselected from weights)
  444. "descriptor_dim": 256,
  445. "add_scale_ori": False,
  446. "n_layers": 9,
  447. "num_heads": 4,
  448. "flash": True, # enable FlashAttention if available.
  449. "mp": False, # enable mixed precision
  450. "depth_confidence": -1, # early stopping, disable with -1
  451. "width_confidence": -1, # point pruning, disable with -1
  452. "filter_threshold": 0.01, # match threshold
  453. "weights": None,
  454. }
  455. # Point pruning involves an overhead (gather).
  456. # Therefore, we only activate it if there are enough keypoints.
  457. pruning_keypoint_thresholds = {
  458. "cpu": -1,
  459. "mps": -1,
  460. "cuda": 1024,
  461. "flash": 1536,
  462. }
  463. required_data_keys = ["image0", "image1"]
  464. version = "v0.1"
  465. url = "https://github.com/url-kaist/MambaGlue/releases/download/{}/{}_mambaglue.tar" # (will be) releases/v0.1/superpoint_mambaglue.tar
  466. # Train your own for now and use it on local
  467. features = {
  468. "superpoint": {
  469. "weights": "superpoint_mambaglue",
  470. "input_dim": 256,
  471. },
  472. "disk": {
  473. "weights": "disk_mambaglue",
  474. "input_dim": 128,
  475. },
  476. "aliked": {
  477. "weights": "aliked_mambaglue",
  478. "input_dim": 128,
  479. },
  480. "sift": {
  481. "weights": "sift_mambaglue",
  482. "input_dim": 128,
  483. "add_scale_ori": True,
  484. },
  485. "doghardnet": {
  486. "weights": "doghardnet_mambaglue",
  487. "input_dim": 128,
  488. "add_scale_ori": True,
  489. },
  490. }
  491. def __init__(self, features="superpoint", **conf) -> None:
  492. super().__init__()
  493. self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
  494. # self.conf = conf = OmegaConf.merge(self.default_conf, conf)
  495. if features is not None:
  496. if features not in self.features:
  497. raise ValueError(
  498. f"Unsupported features: {features} not in "
  499. f"{{{','.join(self.features)}}}"
  500. )
  501. for k, v in self.features[features].items():
  502. setattr(conf, k, v)
  503. if conf.input_dim != conf.descriptor_dim:
  504. self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
  505. else:
  506. self.input_proj = nn.Identity()
  507. head_dim = conf.descriptor_dim // conf.num_heads
  508. self.posenc = LearnableFourierPositionalEncoding(
  509. 2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
  510. )
  511. h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
  512. self.transformermambas = nn.ModuleList(
  513. [TransformerMambaLayer(d, h, conf.flash) for _ in range(n)]
  514. )
  515. self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
  516. self.token_confidence = nn.ModuleList(
  517. [TokenConfidence(d) for _ in range(n - 1)]
  518. )
  519. self.register_buffer(
  520. "confidence_thresholds",
  521. torch.Tensor(
  522. [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
  523. ),
  524. )
  525. state_dict = None
  526. if features is not None:
  527. # When using released weight
  528. # fname = f"{conf.weights}_{self.version.replace('.', '-')}.tar"
  529. # state_dict = torch.hub.load_state_dict_from_url(
  530. # self.url.format(self.version, features), file_name=fname
  531. # )
  532. ##### LOCAL weight
  533. local_path = Path(
  534. "checkpoint_best.tar"
  535. ) # local path for your own weight (.tar or .pth)
  536. print(f"Attempting to load from: {local_path}")
  537. if not local_path.exists():
  538. raise FileNotFoundError(
  539. f"Weights file not found at {local_path}. Please download it manually."
  540. )
  541. checkpoint = torch.load(str(local_path), map_location="cpu")
  542. # Extract only the model weights from the checkpoint
  543. if "model" in checkpoint:
  544. state_dict = checkpoint["model"]
  545. print("Successfully extracted model weights from the checkpoint.")
  546. else:
  547. raise KeyError(
  548. "The checkpoint does not contain 'model' key. Available keys are: ",
  549. checkpoint.keys(),
  550. )
  551. # Load the state dict into your model
  552. self.load_state_dict(state_dict, strict=False)
  553. elif conf.weights is not None:
  554. path = Path(__file__).parent
  555. path = path / "weights/{}.pth".format(self.conf.weights)
  556. state_dict = torch.load(str(path), map_location="cpu")
  557. if state_dict:
  558. # rename mismatched state dict entries
  559. for i in range(self.conf.n_layers):
  560. pattern = "matcher.", ""
  561. state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
  562. pattern = "extractor", ""
  563. state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
  564. self.load_state_dict(state_dict, strict=False)
  565. # static lengths MambaGlue is compiled for (only used with torch.compile)
  566. self.static_lengths = None
  567. def compile(
  568. self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
  569. ):
  570. if self.conf.width_confidence != -1:
  571. warnings.warn(
  572. "Point pruning is partially disabled for compiled forward.",
  573. stacklevel=2,
  574. )
  575. torch._inductor.cudagraph_mark_step_begin()
  576. for i in range(self.conf.n_layers):
  577. self.transformermambas[i].masked_forward = torch.compile(
  578. self.transformermambas[i].masked_forward, mode=mode, fullgraph=True
  579. )
  580. self.static_lengths = static_lengths
  581. def forward(self, data: dict) -> dict:
  582. """
  583. Match keypoints and descriptors between two images
  584. Input (dict):
  585. image0: dict
  586. keypoints: [B x M x 2]
  587. descriptors: [B x M x D]
  588. image: [B x C x H x W] or image_size: [B x 2]
  589. image1: dict
  590. keypoints: [B x N x 2]
  591. descriptors: [B x N x D]
  592. image: [B x C x H x W] or image_size: [B x 2]
  593. Output (dict):
  594. matches0: [B x M]
  595. matching_scores0: [B x M]
  596. matches1: [B x N]
  597. matching_scores1: [B x N]
  598. matches: List[[Si x 2]]
  599. scores: List[[Si]]
  600. stop: int
  601. prune0: [B x M]
  602. prune1: [B x N]
  603. """
  604. with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
  605. return self._forward(data)
  606. def _forward(self, data: dict) -> dict:
  607. for key in self.required_data_keys:
  608. assert key in data, f"Missing key {key} in data"
  609. data0, data1 = data["image0"], data["image1"]
  610. kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
  611. b, m, _ = kpts0.shape
  612. b, n, _ = kpts1.shape
  613. device = kpts0.device
  614. size0, size1 = data0.get("image_size"), data1.get("image_size")
  615. kpts0 = normalize_keypoints(kpts0, size0).clone()
  616. kpts1 = normalize_keypoints(kpts1, size1).clone()
  617. if self.conf.add_scale_ori:
  618. kpts0 = torch.cat(
  619. [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
  620. )
  621. kpts1 = torch.cat(
  622. [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
  623. )
  624. desc0 = data0["descriptors"].detach().contiguous()
  625. desc1 = data1["descriptors"].detach().contiguous()
  626. assert desc0.shape[-1] == self.conf.input_dim
  627. assert desc1.shape[-1] == self.conf.input_dim
  628. if torch.is_autocast_enabled():
  629. desc0 = desc0.half()
  630. desc1 = desc1.half()
  631. mask0, mask1 = None, None
  632. c = max(m, n)
  633. do_compile = self.static_lengths and c <= max(self.static_lengths)
  634. # do_compile = False
  635. if do_compile:
  636. kn = min([k for k in self.static_lengths if k >= c])
  637. desc0, mask0 = pad_to_length(desc0, kn)
  638. desc1, mask1 = pad_to_length(desc1, kn)
  639. kpts0, _ = pad_to_length(kpts0, kn)
  640. kpts1, _ = pad_to_length(kpts1, kn)
  641. desc0 = self.input_proj(desc0)
  642. desc1 = self.input_proj(desc1)
  643. # cache positional embeddings
  644. encoding0 = self.posenc(kpts0)
  645. encoding1 = self.posenc(kpts1)
  646. # GNN + final_proj + assignment
  647. do_early_stop = self.conf.depth_confidence > 0
  648. do_point_pruning = self.conf.width_confidence > 0 and not do_compile
  649. pruning_th = self.pruning_min_kpts(device)
  650. if do_point_pruning:
  651. ind0 = torch.arange(0, m, device=device)[None]
  652. ind1 = torch.arange(0, n, device=device)[None]
  653. # We store the index of the layer at which pruning is detected.
  654. prune0 = torch.ones_like(ind0)
  655. prune1 = torch.ones_like(ind1)
  656. token0, token1 = None, None
  657. for i in range(self.conf.n_layers):
  658. if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
  659. break
  660. desc0, desc1 = self.transformermambas[i](
  661. desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
  662. )
  663. if i == self.conf.n_layers - 1:
  664. continue # no early stopping or adaptive width at last layer
  665. if do_early_stop:
  666. token0, token1 = self.token_confidence[i](desc0, desc1)
  667. if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
  668. break
  669. if do_point_pruning and desc0.shape[-2] > pruning_th:
  670. scores0 = self.log_assignment[i].get_matchability(desc0)
  671. prunemask0 = self.get_pruning_mask(token0, scores0, i)
  672. keep0 = torch.where(prunemask0)[1]
  673. ind0 = ind0.index_select(1, keep0)
  674. desc0 = desc0.index_select(1, keep0)
  675. encoding0 = encoding0.index_select(-2, keep0)
  676. prune0[:, ind0] += 1
  677. if do_point_pruning and desc1.shape[-2] > pruning_th:
  678. scores1 = self.log_assignment[i].get_matchability(desc1)
  679. prunemask1 = self.get_pruning_mask(token1, scores1, i)
  680. keep1 = torch.where(prunemask1)[1]
  681. ind1 = ind1.index_select(1, keep1)
  682. desc1 = desc1.index_select(1, keep1)
  683. encoding1 = encoding1.index_select(-2, keep1)
  684. prune1[:, ind1] += 1
  685. if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
  686. m0 = desc0.new_full((b, m), -1, dtype=torch.long)
  687. m1 = desc1.new_full((b, n), -1, dtype=torch.long)
  688. mscores0 = desc0.new_zeros((b, m))
  689. mscores1 = desc1.new_zeros((b, n))
  690. matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
  691. mscores = desc0.new_empty((b, 0))
  692. if not do_point_pruning:
  693. prune0 = torch.ones_like(mscores0) * self.conf.n_layers
  694. prune1 = torch.ones_like(mscores1) * self.conf.n_layers
  695. return {
  696. "matches0": m0,
  697. "matches1": m1,
  698. "matching_scores0": mscores0,
  699. "matching_scores1": mscores1,
  700. "stop": i + 1,
  701. "matches": matches,
  702. "scores": mscores,
  703. "prune0": prune0,
  704. "prune1": prune1,
  705. }
  706. desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
  707. scores, _ = self.log_assignment[i](desc0, desc1)
  708. m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
  709. matches, mscores = [], []
  710. for k in range(b):
  711. valid = m0[k] > -1
  712. m_indices_0 = torch.where(valid)[0]
  713. m_indices_1 = m0[k][valid]
  714. if do_point_pruning:
  715. m_indices_0 = ind0[k, m_indices_0]
  716. m_indices_1 = ind1[k, m_indices_1]
  717. matches.append(torch.stack([m_indices_0, m_indices_1], -1))
  718. mscores.append(mscores0[k][valid])
  719. if do_point_pruning:
  720. m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
  721. m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
  722. m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
  723. m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
  724. mscores0_ = torch.zeros((b, m), device=mscores0.device)
  725. mscores1_ = torch.zeros((b, n), device=mscores1.device)
  726. mscores0_[:, ind0] = mscores0
  727. mscores1_[:, ind1] = mscores1
  728. m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
  729. else:
  730. prune0 = torch.ones_like(mscores0) * self.conf.n_layers
  731. prune1 = torch.ones_like(mscores1) * self.conf.n_layers
  732. return {
  733. "matches0": m0,
  734. "matches1": m1,
  735. "matching_scores0": mscores0,
  736. "matching_scores1": mscores1,
  737. "stop": i + 1,
  738. "matches": matches,
  739. "scores": mscores,
  740. "prune0": prune0,
  741. "prune1": prune1,
  742. }
  743. def confidence_threshold(self, layer_index: int) -> float:
  744. """scaled confidence threshold"""
  745. threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
  746. return np.clip(threshold, 0, 1)
  747. def get_pruning_mask(
  748. self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
  749. ) -> torch.Tensor:
  750. """mask points which should be removed"""
  751. keep = scores > (1 - self.conf.width_confidence)
  752. if confidences is not None: # Low-confidence points are never pruned.
  753. keep |= confidences <= self.confidence_thresholds[layer_index]
  754. return keep
  755. def check_if_stop(
  756. self,
  757. confidences0: torch.Tensor,
  758. confidences1: torch.Tensor,
  759. layer_index: int,
  760. num_points: int,
  761. ) -> torch.Tensor:
  762. """evaluate stopping condition"""
  763. confidences = torch.cat([confidences0, confidences1], -1)
  764. threshold = self.confidence_thresholds[layer_index]
  765. ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
  766. return ratio_confident > self.conf.depth_confidence
  767. def pruning_min_kpts(self, device: torch.device):
  768. if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
  769. return self.pruning_keypoint_thresholds["flash"]
  770. else:
  771. return self.pruning_keypoint_thresholds[device.type]