aliked.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775
  1. # BSD 3-Clause License
  2. # Copyright (c) 2022, Zhao Xiaoming
  3. # All rights reserved.
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # 1. Redistributions of source code must retain the above copyright notice, this
  7. # list of conditions and the following disclaimer.
  8. # 2. Redistributions in binary form must reproduce the above copyright notice,
  9. # this list of conditions and the following disclaimer in the documentation
  10. # and/or other materials provided with the distribution.
  11. # 3. Neither the name of the copyright holder nor the names of its
  12. # contributors may be used to endorse or promote products derived from
  13. # this software without specific prior written permission.
  14. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  15. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  16. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  17. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  18. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  19. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  20. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  21. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  22. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  23. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  24. # Authors:
  25. # Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
  26. # Code from https://github.com/Shiaoming/ALIKED
  27. from typing import Callable, Optional
  28. import torch
  29. import torch.nn.functional as F
  30. import torchvision
  31. from kornia.color import grayscale_to_rgb
  32. from torch import nn
  33. from torch.nn.modules.utils import _pair
  34. from torchvision.models import resnet
  35. from .utils import Extractor, ImagePreprocessor
  36. def get_patches(
  37. tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
  38. ) -> torch.Tensor:
  39. c, h, w = tensor.shape
  40. corner = (required_corners - ps / 2 + 1).long()
  41. corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
  42. corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
  43. offset = torch.arange(0, ps)
  44. kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
  45. x, y = torch.meshgrid(offset, offset, **kw)
  46. patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
  47. patches = patches.to(corner) + corner[None, None]
  48. pts = patches.reshape(-1, 2)
  49. sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
  50. sampled = sampled.reshape(ps, ps, -1, c)
  51. assert sampled.shape[:3] == patches.shape[:3]
  52. return sampled.permute(2, 3, 0, 1)
  53. def simple_nms(scores: torch.Tensor, nms_radius: int):
  54. """Fast Non-maximum suppression to remove nearby points"""
  55. zeros = torch.zeros_like(scores)
  56. max_mask = scores == torch.nn.functional.max_pool2d(
  57. scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
  58. )
  59. for _ in range(2):
  60. supp_mask = (
  61. torch.nn.functional.max_pool2d(
  62. max_mask.float(),
  63. kernel_size=nms_radius * 2 + 1,
  64. stride=1,
  65. padding=nms_radius,
  66. )
  67. > 0
  68. )
  69. supp_scores = torch.where(supp_mask, zeros, scores)
  70. new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
  71. supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
  72. )
  73. max_mask = max_mask | (new_max_mask & (~supp_mask))
  74. return torch.where(max_mask, scores, zeros)
  75. class DKD(nn.Module):
  76. def __init__(
  77. self,
  78. radius: int = 2,
  79. top_k: int = 0,
  80. scores_th: float = 0.2,
  81. n_limit: int = 20000,
  82. ):
  83. """
  84. Args:
  85. radius: soft detection radius, kernel size is (2 * radius + 1)
  86. top_k: top_k > 0: return top k keypoints
  87. scores_th: top_k <= 0 threshold mode:
  88. scores_th > 0: return keypoints with scores>scores_th
  89. else: return keypoints with scores > scores.mean()
  90. n_limit: max number of keypoint in threshold mode
  91. """
  92. super().__init__()
  93. self.radius = radius
  94. self.top_k = top_k
  95. self.scores_th = scores_th
  96. self.n_limit = n_limit
  97. self.kernel_size = 2 * self.radius + 1
  98. self.temperature = 0.1 # tuned temperature
  99. self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
  100. # local xy grid
  101. x = torch.linspace(-self.radius, self.radius, self.kernel_size)
  102. # (kernel_size*kernel_size) x 2 : (w,h)
  103. kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
  104. self.hw_grid = (
  105. torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
  106. )
  107. def forward(
  108. self,
  109. scores_map: torch.Tensor,
  110. sub_pixel: bool = True,
  111. image_size: Optional[torch.Tensor] = None,
  112. ):
  113. """
  114. :param scores_map: Bx1xHxW
  115. :param descriptor_map: BxCxHxW
  116. :param sub_pixel: whether to use sub-pixel keypoint detection
  117. :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1
  118. """
  119. b, c, h, w = scores_map.shape
  120. scores_nograd = scores_map.detach()
  121. nms_scores = simple_nms(scores_nograd, self.radius)
  122. # remove border
  123. nms_scores[:, :, : self.radius, :] = 0
  124. nms_scores[:, :, :, : self.radius] = 0
  125. if image_size is not None:
  126. for i in range(scores_map.shape[0]):
  127. w, h = image_size[i].long()
  128. nms_scores[i, :, h.item() - self.radius :, :] = 0
  129. nms_scores[i, :, :, w.item() - self.radius :] = 0
  130. else:
  131. nms_scores[:, :, -self.radius :, :] = 0
  132. nms_scores[:, :, :, -self.radius :] = 0
  133. # detect keypoints without grad
  134. if self.top_k > 0:
  135. topk = torch.topk(nms_scores.view(b, -1), self.top_k)
  136. indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
  137. else:
  138. if self.scores_th > 0:
  139. masks = nms_scores > self.scores_th
  140. if masks.sum() == 0:
  141. th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
  142. masks = nms_scores > th.reshape(b, 1, 1, 1)
  143. else:
  144. th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
  145. masks = nms_scores > th.reshape(b, 1, 1, 1)
  146. masks = masks.reshape(b, -1)
  147. indices_keypoints = [] # list, B x (any size)
  148. scores_view = scores_nograd.reshape(b, -1)
  149. for mask, scores in zip(masks, scores_view):
  150. indices = mask.nonzero()[:, 0]
  151. if len(indices) > self.n_limit:
  152. kpts_sc = scores[indices]
  153. sort_idx = kpts_sc.sort(descending=True)[1]
  154. sel_idx = sort_idx[: self.n_limit]
  155. indices = indices[sel_idx]
  156. indices_keypoints.append(indices)
  157. wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
  158. keypoints = []
  159. scoredispersitys = []
  160. kptscores = []
  161. if sub_pixel:
  162. # detect soft keypoints with grad backpropagation
  163. patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
  164. self.hw_grid = self.hw_grid.to(scores_map) # to device
  165. for b_idx in range(b):
  166. patch = patches[b_idx].t() # (H*W) x (kernel**2)
  167. indices_kpt = indices_keypoints[
  168. b_idx
  169. ] # one dimension vector, say its size is M
  170. patch_scores = patch[indices_kpt] # M x (kernel**2)
  171. keypoints_xy_nms = torch.stack(
  172. [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
  173. dim=1,
  174. ) # Mx2
  175. # max is detached to prevent undesired backprop loops in the graph
  176. max_v = patch_scores.max(dim=1).values.detach()[:, None]
  177. x_exp = (
  178. (patch_scores - max_v) / self.temperature
  179. ).exp() # M * (kernel**2), in [0, 1]
  180. # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
  181. xy_residual = (
  182. x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
  183. ) # Soft-argmax, Mx2
  184. hw_grid_dist2 = (
  185. torch.norm(
  186. (self.hw_grid[None, :, :] - xy_residual[:, None, :])
  187. / self.radius,
  188. dim=-1,
  189. )
  190. ** 2
  191. )
  192. scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
  193. # compute result keypoints
  194. keypoints_xy = keypoints_xy_nms + xy_residual
  195. keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
  196. kptscore = torch.nn.functional.grid_sample(
  197. scores_map[b_idx].unsqueeze(0),
  198. keypoints_xy.view(1, 1, -1, 2),
  199. mode="bilinear",
  200. align_corners=True,
  201. )[
  202. 0, 0, 0, :
  203. ] # CxN
  204. keypoints.append(keypoints_xy)
  205. scoredispersitys.append(scoredispersity)
  206. kptscores.append(kptscore)
  207. else:
  208. for b_idx in range(b):
  209. indices_kpt = indices_keypoints[
  210. b_idx
  211. ] # one dimension vector, say its size is M
  212. # To avoid warning: UserWarning: __floordiv__ is deprecated
  213. keypoints_xy_nms = torch.stack(
  214. [indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
  215. dim=1,
  216. ) # Mx2
  217. keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
  218. kptscore = torch.nn.functional.grid_sample(
  219. scores_map[b_idx].unsqueeze(0),
  220. keypoints_xy.view(1, 1, -1, 2),
  221. mode="bilinear",
  222. align_corners=True,
  223. )[
  224. 0, 0, 0, :
  225. ] # CxN
  226. keypoints.append(keypoints_xy)
  227. scoredispersitys.append(kptscore) # for jit.script compatability
  228. kptscores.append(kptscore)
  229. return keypoints, kptscores, scoredispersitys
  230. class InputPadder(object):
  231. """Pads images such that dimensions are divisible by 8"""
  232. def __init__(self, h: int, w: int, divis_by: int = 8):
  233. self.ht = h
  234. self.wd = w
  235. pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
  236. pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
  237. self._pad = [
  238. pad_wd // 2,
  239. pad_wd - pad_wd // 2,
  240. pad_ht // 2,
  241. pad_ht - pad_ht // 2,
  242. ]
  243. def pad(self, x: torch.Tensor):
  244. assert x.ndim == 4
  245. return F.pad(x, self._pad, mode="replicate")
  246. def unpad(self, x: torch.Tensor):
  247. assert x.ndim == 4
  248. ht = x.shape[-2]
  249. wd = x.shape[-1]
  250. c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
  251. return x[..., c[0] : c[1], c[2] : c[3]]
  252. class DeformableConv2d(nn.Module):
  253. def __init__(
  254. self,
  255. in_channels,
  256. out_channels,
  257. kernel_size=3,
  258. stride=1,
  259. padding=1,
  260. bias=False,
  261. mask=False,
  262. ):
  263. super(DeformableConv2d, self).__init__()
  264. self.padding = padding
  265. self.mask = mask
  266. self.channel_num = (
  267. 3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
  268. )
  269. self.offset_conv = nn.Conv2d(
  270. in_channels,
  271. self.channel_num,
  272. kernel_size=kernel_size,
  273. stride=stride,
  274. padding=self.padding,
  275. bias=True,
  276. )
  277. self.regular_conv = nn.Conv2d(
  278. in_channels=in_channels,
  279. out_channels=out_channels,
  280. kernel_size=kernel_size,
  281. stride=stride,
  282. padding=self.padding,
  283. bias=bias,
  284. )
  285. def forward(self, x):
  286. h, w = x.shape[2:]
  287. max_offset = max(h, w) / 4.0
  288. out = self.offset_conv(x)
  289. if self.mask:
  290. o1, o2, mask = torch.chunk(out, 3, dim=1)
  291. offset = torch.cat((o1, o2), dim=1)
  292. mask = torch.sigmoid(mask)
  293. else:
  294. offset = out
  295. mask = None
  296. offset = offset.clamp(-max_offset, max_offset)
  297. x = torchvision.ops.deform_conv2d(
  298. input=x,
  299. offset=offset,
  300. weight=self.regular_conv.weight,
  301. bias=self.regular_conv.bias,
  302. padding=self.padding,
  303. mask=mask,
  304. )
  305. return x
  306. def get_conv(
  307. inplanes,
  308. planes,
  309. kernel_size=3,
  310. stride=1,
  311. padding=1,
  312. bias=False,
  313. conv_type="conv",
  314. mask=False,
  315. ):
  316. if conv_type == "conv":
  317. conv = nn.Conv2d(
  318. inplanes,
  319. planes,
  320. kernel_size=kernel_size,
  321. stride=stride,
  322. padding=padding,
  323. bias=bias,
  324. )
  325. elif conv_type == "dcn":
  326. conv = DeformableConv2d(
  327. inplanes,
  328. planes,
  329. kernel_size=kernel_size,
  330. stride=stride,
  331. padding=_pair(padding),
  332. bias=bias,
  333. mask=mask,
  334. )
  335. else:
  336. raise TypeError
  337. return conv
  338. class ConvBlock(nn.Module):
  339. def __init__(
  340. self,
  341. in_channels,
  342. out_channels,
  343. gate: Optional[Callable[..., nn.Module]] = None,
  344. norm_layer: Optional[Callable[..., nn.Module]] = None,
  345. conv_type: str = "conv",
  346. mask: bool = False,
  347. ):
  348. super().__init__()
  349. if gate is None:
  350. self.gate = nn.ReLU(inplace=True)
  351. else:
  352. self.gate = gate
  353. if norm_layer is None:
  354. norm_layer = nn.BatchNorm2d
  355. self.conv1 = get_conv(
  356. in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
  357. )
  358. self.bn1 = norm_layer(out_channels)
  359. self.conv2 = get_conv(
  360. out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
  361. )
  362. self.bn2 = norm_layer(out_channels)
  363. def forward(self, x):
  364. x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
  365. x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
  366. return x
  367. # modified based on torchvision\models\resnet.py#27->BasicBlock
  368. class ResBlock(nn.Module):
  369. expansion: int = 1
  370. def __init__(
  371. self,
  372. inplanes: int,
  373. planes: int,
  374. stride: int = 1,
  375. downsample: Optional[nn.Module] = None,
  376. groups: int = 1,
  377. base_width: int = 64,
  378. dilation: int = 1,
  379. gate: Optional[Callable[..., nn.Module]] = None,
  380. norm_layer: Optional[Callable[..., nn.Module]] = None,
  381. conv_type: str = "conv",
  382. mask: bool = False,
  383. ) -> None:
  384. super(ResBlock, self).__init__()
  385. if gate is None:
  386. self.gate = nn.ReLU(inplace=True)
  387. else:
  388. self.gate = gate
  389. if norm_layer is None:
  390. norm_layer = nn.BatchNorm2d
  391. if groups != 1 or base_width != 64:
  392. raise ValueError("ResBlock only supports groups=1 and base_width=64")
  393. if dilation > 1:
  394. raise NotImplementedError("Dilation > 1 not supported in ResBlock")
  395. # Both self.conv1 and self.downsample layers
  396. # downsample the input when stride != 1
  397. self.conv1 = get_conv(
  398. inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
  399. )
  400. self.bn1 = norm_layer(planes)
  401. self.conv2 = get_conv(
  402. planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
  403. )
  404. self.bn2 = norm_layer(planes)
  405. self.downsample = downsample
  406. self.stride = stride
  407. def forward(self, x: torch.Tensor) -> torch.Tensor:
  408. identity = x
  409. out = self.conv1(x)
  410. out = self.bn1(out)
  411. out = self.gate(out)
  412. out = self.conv2(out)
  413. out = self.bn2(out)
  414. if self.downsample is not None:
  415. identity = self.downsample(x)
  416. out += identity
  417. out = self.gate(out)
  418. return out
  419. class SDDH(nn.Module):
  420. def __init__(
  421. self,
  422. dims: int,
  423. kernel_size: int = 3,
  424. n_pos: int = 8,
  425. gate=nn.ReLU(),
  426. conv2D=False,
  427. mask=False,
  428. ):
  429. super(SDDH, self).__init__()
  430. self.kernel_size = kernel_size
  431. self.n_pos = n_pos
  432. self.conv2D = conv2D
  433. self.mask = mask
  434. self.get_patches_func = get_patches
  435. # estimate offsets
  436. self.channel_num = 3 * n_pos if mask else 2 * n_pos
  437. self.offset_conv = nn.Sequential(
  438. nn.Conv2d(
  439. dims,
  440. self.channel_num,
  441. kernel_size=kernel_size,
  442. stride=1,
  443. padding=0,
  444. bias=True,
  445. ),
  446. gate,
  447. nn.Conv2d(
  448. self.channel_num,
  449. self.channel_num,
  450. kernel_size=1,
  451. stride=1,
  452. padding=0,
  453. bias=True,
  454. ),
  455. )
  456. # sampled feature conv
  457. self.sf_conv = nn.Conv2d(
  458. dims, dims, kernel_size=1, stride=1, padding=0, bias=False
  459. )
  460. # convM
  461. if not conv2D:
  462. # deformable desc weights
  463. agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
  464. self.register_parameter("agg_weights", agg_weights)
  465. else:
  466. self.convM = nn.Conv2d(
  467. dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
  468. )
  469. def forward(self, x, keypoints):
  470. # x: [B,C,H,W]
  471. # keypoints: list, [[N_kpts,2], ...] (w,h)
  472. b, c, h, w = x.shape
  473. wh = torch.tensor([[w - 1, h - 1]], device=x.device)
  474. max_offset = max(h, w) / 4.0
  475. offsets = []
  476. descriptors = []
  477. # get offsets for each keypoint
  478. for ib in range(b):
  479. xi, kptsi = x[ib], keypoints[ib]
  480. kptsi_wh = (kptsi / 2 + 0.5) * wh
  481. N_kpts = len(kptsi)
  482. if self.kernel_size > 1:
  483. patch = self.get_patches_func(
  484. xi, kptsi_wh.long(), self.kernel_size
  485. ) # [N_kpts, C, K, K]
  486. else:
  487. kptsi_wh_long = kptsi_wh.long()
  488. patch = (
  489. xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
  490. .permute(1, 0)
  491. .reshape(N_kpts, c, 1, 1)
  492. )
  493. offset = self.offset_conv(patch).clamp(
  494. -max_offset, max_offset
  495. ) # [N_kpts, 2*n_pos, 1, 1]
  496. if self.mask:
  497. offset = (
  498. offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
  499. ) # [N_kpts, n_pos, 3]
  500. offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
  501. mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
  502. else:
  503. offset = (
  504. offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
  505. ) # [N_kpts, n_pos, 2]
  506. offsets.append(offset) # for visualization
  507. # get sample positions
  508. pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
  509. pos = 2.0 * pos / wh[None] - 1
  510. pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
  511. # sample features
  512. features = F.grid_sample(
  513. xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
  514. ) # [1,C,(N_kpts*n_pos),1]
  515. features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
  516. 1, 0, 2, 3
  517. ) # [N_kpts, C, n_pos, 1]
  518. if self.mask:
  519. features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
  520. features = torch.selu_(self.sf_conv(features)).squeeze(
  521. -1
  522. ) # [N_kpts, C, n_pos]
  523. # convM
  524. if not self.conv2D:
  525. descs = torch.einsum(
  526. "ncp,pcd->nd", features, self.agg_weights
  527. ) # [N_kpts, C]
  528. else:
  529. features = features.reshape(N_kpts, -1)[
  530. :, :, None, None
  531. ] # [N_kpts, C*n_pos, 1, 1]
  532. descs = self.convM(features).squeeze() # [N_kpts, C]
  533. # normalize
  534. descs = F.normalize(descs, p=2.0, dim=1)
  535. descriptors.append(descs)
  536. return descriptors, offsets
  537. class ALIKED(Extractor):
  538. default_conf = {
  539. "model_name": "aliked-n16",
  540. "max_num_keypoints": -1,
  541. "detection_threshold": 0.2,
  542. "nms_radius": 2,
  543. }
  544. checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
  545. n_limit_max = 20000
  546. # c1, c2, c3, c4, dim, K, M
  547. cfgs = {
  548. "aliked-t16": [8, 16, 32, 64, 64, 3, 16],
  549. "aliked-n16": [16, 32, 64, 128, 128, 3, 16],
  550. "aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
  551. "aliked-n32": [16, 32, 64, 128, 128, 3, 32],
  552. }
  553. preprocess_conf = {
  554. "resize": 1024,
  555. }
  556. required_data_keys = ["image"]
  557. def __init__(self, **conf):
  558. super().__init__(**conf) # Update with default configuration.
  559. conf = self.conf
  560. c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
  561. conv_types = ["conv", "conv", "dcn", "dcn"]
  562. conv2D = False
  563. mask = False
  564. # build model
  565. self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
  566. self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
  567. self.norm = nn.BatchNorm2d
  568. self.gate = nn.SELU(inplace=True)
  569. self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
  570. self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
  571. self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
  572. self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
  573. self.conv1 = resnet.conv1x1(c1, dim // 4)
  574. self.conv2 = resnet.conv1x1(c2, dim // 4)
  575. self.conv3 = resnet.conv1x1(c3, dim // 4)
  576. self.conv4 = resnet.conv1x1(dim, dim // 4)
  577. self.upsample2 = nn.Upsample(
  578. scale_factor=2, mode="bilinear", align_corners=True
  579. )
  580. self.upsample4 = nn.Upsample(
  581. scale_factor=4, mode="bilinear", align_corners=True
  582. )
  583. self.upsample8 = nn.Upsample(
  584. scale_factor=8, mode="bilinear", align_corners=True
  585. )
  586. self.upsample32 = nn.Upsample(
  587. scale_factor=32, mode="bilinear", align_corners=True
  588. )
  589. self.score_head = nn.Sequential(
  590. resnet.conv1x1(dim, 8),
  591. self.gate,
  592. resnet.conv3x3(8, 4),
  593. self.gate,
  594. resnet.conv3x3(4, 4),
  595. self.gate,
  596. resnet.conv3x3(4, 1),
  597. )
  598. self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
  599. self.dkd = DKD(
  600. radius=conf.nms_radius,
  601. top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
  602. scores_th=conf.detection_threshold,
  603. n_limit=(
  604. conf.max_num_keypoints
  605. if conf.max_num_keypoints > 0
  606. else self.n_limit_max
  607. ),
  608. )
  609. state_dict = torch.hub.load_state_dict_from_url(
  610. self.checkpoint_url.format(conf.model_name), map_location="cpu"
  611. )
  612. self.load_state_dict(state_dict, strict=True)
  613. def get_resblock(self, c_in, c_out, conv_type, mask):
  614. return ResBlock(
  615. c_in,
  616. c_out,
  617. 1,
  618. nn.Conv2d(c_in, c_out, 1),
  619. gate=self.gate,
  620. norm_layer=self.norm,
  621. conv_type=conv_type,
  622. mask=mask,
  623. )
  624. def extract_dense_map(self, image):
  625. # Pads images such that dimensions are divisible by
  626. div_by = 2**5
  627. padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
  628. image = padder.pad(image)
  629. # ================================== feature encoder
  630. x1 = self.block1(image) # B x c1 x H x W
  631. x2 = self.pool2(x1)
  632. x2 = self.block2(x2) # B x c2 x H/2 x W/2
  633. x3 = self.pool4(x2)
  634. x3 = self.block3(x3) # B x c3 x H/8 x W/8
  635. x4 = self.pool4(x3)
  636. x4 = self.block4(x4) # B x dim x H/32 x W/32
  637. # ================================== feature aggregation
  638. x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
  639. x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
  640. x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
  641. x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
  642. x2_up = self.upsample2(x2) # B x dim//4 x H x W
  643. x3_up = self.upsample8(x3) # B x dim//4 x H x W
  644. x4_up = self.upsample32(x4) # B x dim//4 x H x W
  645. x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
  646. # ================================== score head
  647. score_map = torch.sigmoid(self.score_head(x1234))
  648. feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
  649. # Unpads images
  650. feature_map = padder.unpad(feature_map)
  651. score_map = padder.unpad(score_map)
  652. return feature_map, score_map
  653. def describe(
  654. self, keypoints: torch.Tensor, img: torch.Tensor, **conf
  655. ) -> torch.Tensor:
  656. """Extract descriptors for a set of keypoints."""
  657. if img.dim() == 3:
  658. img = img[None] # add batch dim
  659. assert img.dim() == 4 and img.shape[0] == 1
  660. w, h = img.shape[-2:][::-1]
  661. wh = torch.tensor([w - 1, h - 1], device=img.device)
  662. img, _ = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
  663. keypoints_n = 2.0 * keypoints / wh[None, None] - 1 # [-1, 1]
  664. # Extract dense features on resized img
  665. feature_map, _ = self.extract_dense_map(img)
  666. return torch.stack(self.desc_head(feature_map, keypoints_n)[0])
  667. def forward(self, data: dict) -> dict:
  668. image = data["image"]
  669. if image.shape[1] == 1:
  670. image = grayscale_to_rgb(image)
  671. feature_map, score_map = self.extract_dense_map(image)
  672. keypoints, kptscores, scoredispersitys = self.dkd(
  673. score_map, image_size=data.get("image_size")
  674. )
  675. descriptors, offsets = self.desc_head(feature_map, keypoints)
  676. _, _, h, w = image.shape
  677. wh = torch.tensor([w - 1, h - 1], device=image.device)
  678. # no padding required
  679. # we can set detection_threshold=-1 and conf.max_num_keypoints > 0
  680. return {
  681. "keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
  682. "descriptors": torch.stack(descriptors), # B x N x D
  683. "keypoint_scores": torch.stack(kptscores), # B x N
  684. }