mkd.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, List, Tuple, Union
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from kornia.constants import pi
  22. from kornia.core import Tensor, cos, sin, tensor, zeros
  23. from kornia.filters import GaussianBlur2d, SpatialGradient
  24. from kornia.geometry.conversions import cart2pol
  25. from kornia.utils import create_meshgrid
  26. # Precomputed coefficients for Von Mises kernel, given N and K(appa).
  27. sqrt2: float = 1.4142135623730951
  28. COEFFS_N1_K1: List[float] = [0.38214156, 0.48090413]
  29. COEFFS_N2_K8: List[float] = [0.14343168, 0.268285, 0.21979234]
  30. COEFFS_N3_K8: List[float] = [0.14343168, 0.268285, 0.21979234, 0.15838885]
  31. COEFFS: Dict[str, List[float]] = {"xy": COEFFS_N1_K1, "rhophi": COEFFS_N2_K8, "theta": COEFFS_N3_K8}
  32. urls: Dict[str, str] = {
  33. k: f"https://github.com/manyids2/mkd_pytorch/raw/master/mkd_pytorch/mkd-{k}-64.pth"
  34. for k in ["cart", "polar", "concat"]
  35. }
  36. def get_grid_dict(patch_size: int = 32) -> Dict[str, Tensor]:
  37. r"""Get cartesian and polar parametrizations of grid."""
  38. kgrid = create_meshgrid(height=patch_size, width=patch_size, normalized_coordinates=True)
  39. x = kgrid[0, :, :, 0]
  40. y = kgrid[0, :, :, 1]
  41. rho, phi = cart2pol(x, y)
  42. grid_dict = {"x": x, "y": y, "rho": rho, "phi": phi}
  43. return grid_dict
  44. def get_kron_order(d1: int, d2: int) -> Tensor:
  45. r"""Get order for doing kronecker product."""
  46. grid_d1, grid_d2 = torch.meshgrid(torch.arange(d1), torch.arange(d2), indexing="ij")
  47. kron_order = torch.stack([grid_d1, grid_d2], dim=2).reshape(-1, 2)
  48. return kron_order.to(torch.int64)
  49. class MKDGradients(nn.Module):
  50. r"""Module, which computes gradients of given patches, stacked as [magnitudes, orientations].
  51. Given gradients $g_x$, $g_y$ with respect to $x$, $y$ respectively,
  52. - $\mathbox{mags} = $\sqrt{g_x^2 + g_y^2 + eps}$
  53. - $\mathbox{oris} = $\mbox{tan}^{-1}(\nicefrac{g_y}{g_x})$.
  54. Args:
  55. patch_size: Input patch size in pixels.
  56. Returns:
  57. gradients of given patches.
  58. Shape:
  59. - Input: (B, 1, patch_size, patch_size)
  60. - Output: (B, 2, patch_size, patch_size)
  61. Example:
  62. >>> patches = torch.rand(23, 1, 32, 32)
  63. >>> gradient = MKDGradients()
  64. >>> g = gradient(patches) # 23x2x32x32
  65. """
  66. def __init__(self) -> None:
  67. super().__init__()
  68. self.eps = 1e-8
  69. self.grad = SpatialGradient(mode="diff", order=1, normalized=False)
  70. def forward(self, x: Tensor) -> Tensor:
  71. if not isinstance(x, Tensor):
  72. raise TypeError(f"Input type is not a Tensor. Got {type(x)}")
  73. if not len(x.shape) == 4:
  74. raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {x.shape}")
  75. # Modify 'diff' gradient. Before we had lambda function, but it is not jittable
  76. grads_xy = -self.grad(x)
  77. gx = grads_xy[:, :, 0, :, :]
  78. gy = grads_xy[:, :, 1, :, :]
  79. y = torch.cat(cart2pol(gx, gy, self.eps), dim=1)
  80. return y
  81. def __repr__(self) -> str:
  82. return self.__class__.__name__
  83. class VonMisesKernel(nn.Module):
  84. r"""Module, which computes parameters of Von Mises kernel given coefficients, and embeds given patches.
  85. Args:
  86. patch_size: Input patch size in pixels.
  87. coeffs: List of coefficients. Some examples are hardcoded in COEFFS,
  88. Returns:
  89. Von Mises embedding of given parametrization.
  90. Shape:
  91. - Input: (B, 1, patch_size, patch_size)
  92. - Output: (B, d, patch_size, patch_size)
  93. Examples:
  94. >>> oris = torch.rand(23, 1, 32, 32)
  95. >>> vm = VonMisesKernel(patch_size=32,
  96. ... coeffs=[0.14343168,
  97. ... 0.268285,
  98. ... 0.21979234])
  99. >>> emb = vm(oris) # 23x7x32x32
  100. """
  101. def __init__(self, patch_size: int, coeffs: Union[List[Union[float, int]], Tuple[Union[float, int], ...]]) -> None:
  102. super().__init__()
  103. self.patch_size = patch_size
  104. b_coeffs = tensor(coeffs)
  105. self.register_buffer("coeffs", b_coeffs)
  106. # Compute parameters.
  107. n = len(coeffs) - 1
  108. self.n = n
  109. self.d = 2 * n + 1
  110. # Precompute helper variables.
  111. emb0 = torch.ones([1, 1, patch_size, patch_size])
  112. frange = torch.arange(n) + 1
  113. frange = frange.reshape(-1, 1, 1)
  114. weights = zeros([2 * n + 1])
  115. weights[: n + 1] = torch.sqrt(b_coeffs)
  116. weights[n + 1 :] = torch.sqrt(b_coeffs[1:])
  117. weights = weights.reshape(-1, 1, 1)
  118. self.register_buffer("emb0", emb0)
  119. self.register_buffer("frange", frange)
  120. self.register_buffer("weights", weights)
  121. def forward(self, x: Tensor) -> Tensor:
  122. if not isinstance(x, Tensor):
  123. raise TypeError(f"Input type is not a Tensor. Got {type(x)}")
  124. if not len(x.shape) == 4 or x.shape[1] != 1:
  125. raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {x.shape}")
  126. if not isinstance(self.emb0, Tensor):
  127. raise TypeError(f"Emb0 type is not a Tensor. Got {type(x)}")
  128. emb0 = self.emb0.to(x).repeat(x.size(0), 1, 1, 1)
  129. frange = self.frange.to(x) * x
  130. emb1 = cos(frange)
  131. emb2 = sin(frange)
  132. embedding = torch.cat([emb0, emb1, emb2], dim=1)
  133. embedding = self.weights * embedding
  134. return embedding
  135. def __repr__(self) -> str:
  136. return f"{self.__class__.__name__}(patch_size={self.patch_size}, n={self.n}, d={self.d}, coeffs={self.coeffs})"
  137. class EmbedGradients(nn.Module):
  138. r"""Module that computes gradient embedding, weighted by sqrt of magnitudes of given patches.
  139. Args:
  140. patch_size: Input patch size in pixels.
  141. relative: absolute or relative gradients.
  142. Returns:
  143. Gradient embedding.
  144. Shape:
  145. - Input: (B, 2, patch_size, patch_size)
  146. - Output: (B, 7, patch_size, patch_size)
  147. Examples:
  148. >>> grads = torch.rand(23, 2, 32, 32)
  149. >>> emb_grads = EmbedGradients(patch_size=32,
  150. ... relative=False)
  151. >>> emb = emb_grads(grads) # 23x7x32x32
  152. """
  153. def __init__(self, patch_size: int = 32, relative: bool = False) -> None:
  154. super().__init__()
  155. self.patch_size = patch_size
  156. self.relative = relative
  157. self.eps = 1e-8
  158. # Theta kernel for gradients.
  159. self.kernel = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS["theta"])
  160. # Relative gradients.
  161. kgrid = create_meshgrid(height=patch_size, width=patch_size, normalized_coordinates=True)
  162. _, phi = cart2pol(kgrid[:, :, :, 0], kgrid[:, :, :, 1])
  163. self.register_buffer("phi", phi)
  164. def emb_mags(self, mags: Tensor) -> Tensor:
  165. """Embed square roots of magnitudes with eps for numerical reasons."""
  166. mags = torch.sqrt(mags + self.eps)
  167. return mags
  168. def forward(self, grads: Tensor) -> Tensor:
  169. if not isinstance(grads, Tensor):
  170. raise TypeError(f"Input type is not a Tensor. Got {type(grads)}")
  171. if not len(grads.shape) == 4:
  172. raise ValueError(f"Invalid input shape, we expect Bx2xHxW. Got: {grads.shape}")
  173. mags = grads[:, :1, :, :]
  174. oris = grads[:, 1:, :, :]
  175. if self.relative:
  176. oris = oris - self.phi.to(oris)
  177. y = self.kernel(oris) * self.emb_mags(mags)
  178. return y
  179. def __repr__(self) -> str:
  180. return f"{self.__class__.__name__}(patch_size={self.patch_size}, relative={self.relative})"
  181. def spatial_kernel_embedding(kernel_type: str, grids: Dict[str, Tensor]) -> Tensor:
  182. r"""Compute embeddings for cartesian and polar parametrizations."""
  183. factors = {"phi": 1.0, "rho": pi / sqrt2, "x": pi / 2, "y": pi / 2}
  184. if kernel_type == "cart":
  185. coeffs_ = "xy"
  186. params_ = ["x", "y"]
  187. elif kernel_type == "polar":
  188. coeffs_ = "rhophi"
  189. params_ = ["phi", "rho"]
  190. # Infer patch_size.
  191. keys = list(grids.keys())
  192. patch_size = grids[keys[0]].shape[-1]
  193. # Scale appropriately.
  194. grids_normed = {k: v * factors[k] for k, v in grids.items()}
  195. grids_normed = {k: v.unsqueeze(0).unsqueeze(0).float() for k, v in grids_normed.items()}
  196. # x,y/rho,phi kernels.
  197. vm_a = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS[coeffs_])
  198. vm_b = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS[coeffs_])
  199. emb_a = vm_a(grids_normed[params_[0]]).squeeze()
  200. emb_b = vm_b(grids_normed[params_[1]]).squeeze()
  201. # Final precomputed position embedding.
  202. kron_order = get_kron_order(vm_a.d, vm_b.d)
  203. spatial_kernel = emb_a.index_select(0, kron_order[:, 0]) * emb_b.index_select(0, kron_order[:, 1])
  204. return spatial_kernel
  205. class ExplicitSpacialEncoding(nn.Module):
  206. r"""Module that computes explicit cartesian or polar embedding.
  207. Args:
  208. kernel_type: Parametrization of kernel ``'polar'`` or ``'cart'``.
  209. fmap_size: Input feature map size in pixels.
  210. in_dims: Dimensionality of input feature map.
  211. do_gmask: Apply gaussian mask.
  212. do_l2: Apply l2-normalization.
  213. Returns:
  214. Explicit cartesian or polar embedding.
  215. Shape:
  216. - Input: (B, in_dims, fmap_size, fmap_size)
  217. - Output: (B, out_dims, fmap_size, fmap_size)
  218. Example:
  219. >>> emb_ori = torch.rand(23, 7, 32, 32)
  220. >>> ese = ExplicitSpacialEncoding(kernel_type='polar',
  221. ... fmap_size=32,
  222. ... in_dims=7,
  223. ... do_gmask=True,
  224. ... do_l2=True)
  225. >>> desc = ese(emb_ori) # 23x175x32x32
  226. """
  227. def __init__(
  228. self,
  229. kernel_type: str = "polar",
  230. fmap_size: int = 32,
  231. in_dims: int = 7,
  232. do_gmask: bool = True,
  233. do_l2: bool = True,
  234. ) -> None:
  235. super().__init__()
  236. if kernel_type not in ["polar", "cart"]:
  237. raise NotImplementedError(f"{kernel_type} is not valid, use polar or cart).")
  238. self.kernel_type = kernel_type
  239. self.fmap_size = fmap_size
  240. self.in_dims = in_dims
  241. self.do_gmask = do_gmask
  242. self.do_l2 = do_l2
  243. self.grid = get_grid_dict(fmap_size)
  244. self.gmask = None
  245. # Precompute embedding.
  246. emb = spatial_kernel_embedding(self.kernel_type, self.grid)
  247. # Gaussian mask.
  248. if self.do_gmask:
  249. self.gmask = self.get_gmask(sigma=1.0)
  250. emb = emb * self.gmask
  251. # Store precomputed embedding.
  252. self.register_buffer("emb", emb.unsqueeze(0))
  253. self.d_emb: int = emb.shape[0]
  254. self.out_dims: int = self.in_dims * self.d_emb
  255. self.odims: int = self.out_dims
  256. # Store kronecker form.
  257. emb2, idx1 = self.init_kron()
  258. self.register_buffer("emb2", emb2)
  259. self.register_buffer("idx1", idx1)
  260. def get_gmask(self, sigma: float) -> Tensor:
  261. """Compute Gaussian mask."""
  262. norm_rho = self.grid["rho"] / self.grid["rho"].max()
  263. gmask = torch.exp(-1 * norm_rho**2 / sigma**2)
  264. return gmask
  265. def init_kron(self) -> Tuple[Tensor, Tensor]:
  266. """Initialize helper variables to calculate kronecker."""
  267. kron = get_kron_order(self.in_dims, self.d_emb)
  268. _emb = torch.jit.annotate(Tensor, self.emb)
  269. emb2 = torch.index_select(_emb, 1, kron[:, 1])
  270. return emb2, kron[:, 0]
  271. def forward(self, x: Tensor) -> Tensor:
  272. if not isinstance(x, Tensor):
  273. raise TypeError(f"Input type is not a Tensor. Got {type(x)}")
  274. if not ((len(x.shape) == 4) | (x.shape[1] == self.in_dims)):
  275. raise ValueError(f"Invalid input shape, we expect Bx{self.in_dims}xHxW. Got: {x.shape}")
  276. idx1 = torch.jit.annotate(Tensor, self.idx1)
  277. emb1 = torch.index_select(x, 1, idx1)
  278. output = emb1 * self.emb2
  279. output = output.sum(dim=(2, 3))
  280. if self.do_l2:
  281. output = F.normalize(output, dim=1)
  282. return output
  283. def __repr__(self) -> str:
  284. return (
  285. f"{self.__class__.__name__}("
  286. f"kernel_type={self.kernel_type}, "
  287. f"fmap_size={self.fmap_size}, "
  288. f"in_dims={self.in_dims}, "
  289. f"out_dims={self.out_dims}, "
  290. f"do_gmask={self.do_gmask}, "
  291. f"do_l2={self.do_l2})"
  292. )
  293. class Whitening(nn.Module):
  294. r"""Module, performs supervised or unsupervised whitening.
  295. This is based on the paper "Understanding and Improving Kernel Local Descriptors".
  296. See :cite:`mukundan2019understanding` for more details.
  297. Args:
  298. xform: Variant of whitening to use. None, 'lw', 'pca', 'pcaws', 'pcawt'.
  299. whitening_model: Dictionary with keys 'mean', 'eigvecs', 'eigvals' holding Tensors.
  300. in_dims: Dimensionality of input descriptors.
  301. output_dims: (int) Dimensionality reduction.
  302. keval: Shrinkage parameter.
  303. t: Attenuation parameter.
  304. Returns:
  305. l2-normalized, whitened descriptors.
  306. Shape:
  307. - Input: (B, in_dims, fmap_size, fmap_size)
  308. - Output: (B, out_dims, fmap_size, fmap_size)
  309. Examples:
  310. >>> descs = torch.rand(23, 238)
  311. >>> whitening_model = {'pca': {'mean': torch.zeros(238),
  312. ... 'eigvecs': torch.eye(238),
  313. ... 'eigvals': torch.ones(238)}}
  314. >>> whitening = Whitening(xform='pcawt',
  315. ... whitening_model=whitening_model,
  316. ... in_dims=238,
  317. ... output_dims=128,
  318. ... keval=40,
  319. ... t=0.7)
  320. >>> wdescs = whitening(descs) # 23x128
  321. """
  322. def __init__(
  323. self,
  324. xform: str,
  325. whitening_model: Union[Dict[str, Dict[str, Tensor]], None],
  326. in_dims: int,
  327. output_dims: int = 128,
  328. keval: int = 40,
  329. t: float = 0.7,
  330. ) -> None:
  331. super().__init__()
  332. self.xform = xform
  333. self.in_dims = in_dims
  334. self.keval = keval
  335. self.t = t
  336. self.pval = 1.0
  337. # Compute true output_dims.
  338. output_dims = min(output_dims, in_dims)
  339. self.output_dims = output_dims
  340. # Initialize identity transform.
  341. self.mean = nn.Parameter(zeros(in_dims), requires_grad=True)
  342. self.evecs = nn.Parameter(torch.eye(in_dims)[:, :output_dims], requires_grad=True)
  343. self.evals = nn.Parameter(torch.ones(in_dims)[:output_dims], requires_grad=True)
  344. if whitening_model is not None:
  345. self.load_whitening_parameters(whitening_model)
  346. def load_whitening_parameters(self, whitening_model: Dict[str, Dict[str, Tensor]]) -> None:
  347. algo = "lw" if self.xform == "lw" else "pca"
  348. wh_model = whitening_model[algo]
  349. self.mean.data = wh_model["mean"]
  350. self.evecs.data = wh_model["eigvecs"][:, : self.output_dims]
  351. self.evals.data = wh_model["eigvals"][: self.output_dims]
  352. modifications = {
  353. "pca": self._modify_pca,
  354. "lw": self._modify_lw,
  355. "pcaws": self._modify_pcaws,
  356. "pcawt": self._modify_pcawt,
  357. }
  358. # Call modification.
  359. modifications[self.xform]()
  360. def _modify_pca(self) -> None:
  361. """Modify powerlaw parameter."""
  362. self.pval = 0.5
  363. def _modify_lw(self) -> None:
  364. """No modification required."""
  365. def _modify_pcaws(self) -> None:
  366. """Shrinkage for eigenvalues."""
  367. alpha = self.evals[self.keval]
  368. evals = ((1 - alpha) * self.evals) + alpha
  369. self.evecs.data = self.evecs @ torch.diag(torch.pow(evals, -0.5))
  370. def _modify_pcawt(self) -> None:
  371. """Attenuation for eigenvalues."""
  372. m = -0.5 * self.t
  373. self.evecs.data = self.evecs @ torch.diag(torch.pow(self.evals, m))
  374. def forward(self, x: Tensor) -> Tensor:
  375. if not isinstance(x, Tensor):
  376. raise TypeError(f"Input type is not a Tensor. Got {type(x)}")
  377. if not len(x.shape) == 2:
  378. raise ValueError(f"Invalid input shape, we expect NxD. Got: {x.shape}")
  379. x = x - self.mean # Center the data.
  380. x = x @ self.evecs # Apply rotation and/or scaling.
  381. x = torch.sign(x) * torch.pow(torch.abs(x), self.pval) # Powerlaw.
  382. return F.normalize(x, dim=1)
  383. def __repr__(self) -> str:
  384. return f"{self.__class__.__name__}(xform={self.xform}, in_dims={self.in_dims}, output_dims={self.output_dims})"
  385. class MKDDescriptor(nn.Module):
  386. r"""Module that computes Multiple Kernel local descriptors.
  387. This is based on the paper "Understanding and Improving Kernel Local Descriptors".
  388. See :cite:`mukundan2019understanding` for more details.
  389. Args:
  390. patch_size: Input patch size in pixels.
  391. kernel_type: Parametrization of kernel ``'concat'``, ``'cart'``, ``'polar'``.
  392. whitening: Whitening transform to apply ``None``, ``'lw'``, ``'pca'``, ``'pcawt'``, ``'pcaws'``.
  393. training_set: Set that model was trained on ``'liberty'``, ``'notredame'``, ``'yosemite'``.
  394. output_dims: Dimensionality reduction.
  395. Returns:
  396. Explicit cartesian or polar embedding.
  397. Shape:
  398. - Input: :math:`(B, in_{dims}, fmap_{size}, fmap_{size})`.
  399. - Output: :math:`(B, out_{dims}, fmap_{size}, fmap_{size})`,
  400. Examples:
  401. >>> patches = torch.rand(23, 1, 32, 32)
  402. >>> mkd = MKDDescriptor(patch_size=32,
  403. ... kernel_type='concat',
  404. ... whitening='pcawt',
  405. ... training_set='liberty',
  406. ... output_dims=128)
  407. >>> desc = mkd(patches) # 23x128
  408. """
  409. def __init__(
  410. self,
  411. patch_size: int = 32,
  412. kernel_type: str = "concat",
  413. whitening: str = "pcawt",
  414. training_set: str = "liberty",
  415. output_dims: int = 128,
  416. ) -> None:
  417. super().__init__()
  418. self.patch_size: int = patch_size
  419. self.kernel_type: str = kernel_type
  420. self.whitening: str = whitening
  421. self.training_set: str = training_set
  422. self.sigma = 1.4 * (patch_size / 64)
  423. self.smoothing = GaussianBlur2d((5, 5), (self.sigma, self.sigma), "replicate")
  424. self.gradients = MKDGradients()
  425. # This stupid thing needed for jitting...
  426. polar_s: str = "polar"
  427. cart_s: str = "cart"
  428. self.parametrizations = [polar_s, cart_s] if self.kernel_type == "concat" else [self.kernel_type]
  429. # Initialize cartesian/polar embedding with absolute/relative gradients.
  430. self.odims: int = 0
  431. relative_orientations = {polar_s: True, cart_s: False}
  432. self.feats = {}
  433. for parametrization in self.parametrizations:
  434. gradient_embedding = EmbedGradients(patch_size=patch_size, relative=relative_orientations[parametrization])
  435. spatial_encoding = ExplicitSpacialEncoding(
  436. kernel_type=parametrization, fmap_size=patch_size, in_dims=gradient_embedding.kernel.d
  437. )
  438. self.feats[parametrization] = nn.Sequential(gradient_embedding, spatial_encoding)
  439. self.odims += spatial_encoding.odims
  440. # Compute true output_dims.
  441. self.output_dims: int = min(output_dims, self.odims)
  442. # Load supervised(lw)/unsupervised(pca) model trained on training_set.
  443. if self.whitening is not None:
  444. whitening_models = torch.hub.load_state_dict_from_url(
  445. urls[self.kernel_type], map_location=torch.device("cpu")
  446. )
  447. whitening_model = whitening_models[training_set]
  448. self.whitening_layer = Whitening(
  449. whitening, whitening_model, in_dims=self.odims, output_dims=self.output_dims
  450. )
  451. self.odims = self.output_dims
  452. self.eval()
  453. def forward(self, patches: Tensor) -> Tensor:
  454. if not isinstance(patches, Tensor):
  455. raise TypeError(f"Input type is not a Tensor. Got {type(patches)}")
  456. if not len(patches.shape) == 4:
  457. raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {patches.shape}")
  458. # Extract gradients.
  459. g = self.smoothing(patches)
  460. g = self.gradients(g)
  461. # Extract polar/cart features.
  462. features = []
  463. for parametrization in self.parametrizations:
  464. self.feats[parametrization].to(g.device)
  465. features.append(self.feats[parametrization](g))
  466. # Concatenate.
  467. y = torch.cat(features, dim=1)
  468. # l2-normalize.
  469. y = F.normalize(y, dim=1)
  470. # Whiten descriptors.
  471. if self.whitening is not None:
  472. y = self.whitening_layer(y)
  473. return y
  474. def __repr__(self) -> str:
  475. return (
  476. f"{self.__class__.__name__}("
  477. f"patch_size={self.patch_size}, "
  478. f"kernel_type={self.kernel_type}, "
  479. f"whitening={self.whitening}, "
  480. f"training_set={self.training_set}, "
  481. f"output_dims={self.output_dims})"
  482. )
  483. def load_whitening_model(kernel_type: str, training_set: str) -> Dict[str, Any]:
  484. """Load whitening model."""
  485. whitening_models = torch.hub.load_state_dict_from_url(urls[kernel_type], map_location=torch.device("cpu"))
  486. whitening_model = whitening_models[training_set]
  487. return whitening_model
  488. class SimpleKD(nn.Module):
  489. """Example to write custom Kernel Descriptors."""
  490. def __init__(
  491. self,
  492. patch_size: int = 32,
  493. kernel_type: str = "polar", # 'cart' 'polar'
  494. whitening: str = "pcawt", # 'lw', 'pca', 'pcaws', 'pcawt
  495. training_set: str = "liberty", # 'liberty', 'notredame', 'yosemite'
  496. output_dims: int = 128,
  497. ) -> None:
  498. super().__init__()
  499. relative: bool = kernel_type == "polar"
  500. sigma: float = 1.4 * (patch_size / 64)
  501. self.patch_size = patch_size
  502. # Sequence of modules.
  503. smoothing = GaussianBlur2d((5, 5), (sigma, sigma), "replicate")
  504. gradients = MKDGradients()
  505. ori = EmbedGradients(patch_size=patch_size, relative=relative)
  506. ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=patch_size, in_dims=ori.kernel.d)
  507. wh = Whitening(
  508. whitening, load_whitening_model(kernel_type, training_set), in_dims=ese.odims, output_dims=output_dims
  509. )
  510. self.features = nn.Sequential(smoothing, gradients, ori, ese, wh)
  511. def forward(self, x: Tensor) -> Tensor:
  512. return self.features(x)