siftdesc.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import math
  18. from typing import Tuple
  19. import torch
  20. from torch import nn
  21. from kornia.core import Module, Tensor, concatenate, eye, normalize
  22. from kornia.core.check import KORNIA_CHECK_SHAPE
  23. from kornia.filters import get_gaussian_kernel2d, spatial_gradient
  24. from kornia.geometry.conversions import pi
  25. def _get_reshape_kernel(kd: int, ky: int, kx: int) -> Tensor:
  26. """Return neigh2channels conv kernel."""
  27. numel: int = kd * ky * kx
  28. # Fast-path: use static _eye_cache if available for small numel
  29. # (to avoid repeated allocations for common kernel sizes)
  30. # The cache size is limited for memory efficiency.
  31. # NOTE: If memory is a concern and large kd/ky/kx are rare, adjust _MAX_CACHED.
  32. _MAX_CACHED = 4096
  33. if numel <= _MAX_CACHED:
  34. if not hasattr(_get_reshape_kernel, "_eye_cache"):
  35. _get_reshape_kernel._eye_cache = {} # type: ignore[attr-defined]
  36. cache = _get_reshape_kernel._eye_cache # type: ignore[attr-defined]
  37. res = cache.get(numel)
  38. if res is None:
  39. res = eye(numel)
  40. cache[numel] = res
  41. return res.view(numel, kd, ky, kx)
  42. else:
  43. # fallback to normal allocation for big kernels
  44. return eye(numel).view(numel, kd, ky, kx)
  45. def get_sift_pooling_kernel(ksize: int = 25) -> Tensor:
  46. r"""Return a weighted pooling kernel for SIFT descriptor.
  47. Args:
  48. ksize: kernel_size.
  49. Returns:
  50. the pooling kernel with shape :math:`(ksize, ksize)`.
  51. """
  52. ks_2: float = float(ksize) / 2.0
  53. xc2 = ks_2 - (torch.arange(ksize).float() + 0.5 - ks_2).abs()
  54. kernel = torch.ger(xc2, xc2) / (ks_2**2)
  55. return kernel
  56. def get_sift_bin_ksize_stride_pad(patch_size: int, num_spatial_bins: int) -> Tuple[int, int, int]:
  57. r"""Return a tuple with SIFT parameters.
  58. Args:
  59. patch_size: the given patch size.
  60. num_spatial_bins: the ggiven number of spatial bins.
  61. Returns:
  62. ksize, stride, pad.
  63. """
  64. ksize: int = 2 * int(patch_size / (num_spatial_bins + 1))
  65. stride: int = patch_size // num_spatial_bins
  66. pad: int = ksize // 4
  67. out_size: int = (patch_size + 2 * pad - (ksize - 1) - 1) // stride + 1
  68. if out_size != num_spatial_bins:
  69. raise ValueError(
  70. f"Patch size {patch_size} is incompatible with requested number of spatial bins"
  71. f" {num_spatial_bins} for SIFT descriptor. Usually it happens when patch size is too small "
  72. " for num_spatial_bins specified"
  73. )
  74. return ksize, stride, pad
  75. class SIFTDescriptor(Module):
  76. r"""Module which computes SIFT descriptors of given patches.
  77. Args:
  78. patch_size: Input patch size in pixels.
  79. num_ang_bins: Number of angular bins.
  80. num_spatial_bins: Number of spatial bins.
  81. clipval: clipping value to reduce single-bin dominance
  82. rootsift: if ``True``, RootSIFT (Arandjelović et. al, 2012) is computed.
  83. Returns:
  84. SIFT descriptor of the patches with shape.
  85. Shape:
  86. - Input: :math:`(B, 1, \text{num_spatial_bins}, \text{num_spatial_bins})`
  87. - Output: :math:`(B, \text{num_ang_bins * num_spatial_bins ** 2})`
  88. Example:
  89. >>> input = torch.rand(23, 1, 32, 32)
  90. >>> SIFT = SIFTDescriptor(32, 8, 4)
  91. >>> descs = SIFT(input) # 23x128
  92. """
  93. def __repr__(self) -> str:
  94. return (
  95. f"{self.__class__.__name__}("
  96. f"num_ang_bins={self.num_ang_bins}, "
  97. f"num_spatial_bins={self.num_spatial_bins}, "
  98. f"patch_size={self.patch_size}, "
  99. f"rootsift={self.rootsift}, "
  100. f"clipval={self.clipval})"
  101. )
  102. def __init__(
  103. self,
  104. patch_size: int = 41,
  105. num_ang_bins: int = 8,
  106. num_spatial_bins: int = 4,
  107. rootsift: bool = True,
  108. clipval: float = 0.2,
  109. ) -> None:
  110. super().__init__()
  111. self.eps = 1e-10
  112. self.num_ang_bins = num_ang_bins
  113. self.num_spatial_bins = num_spatial_bins
  114. self.clipval = clipval
  115. self.rootsift = rootsift
  116. self.patch_size = patch_size
  117. ks: int = self.patch_size
  118. sigma: float = float(ks) / math.sqrt(2.0)
  119. self.gk = get_gaussian_kernel2d((ks, ks), (sigma, sigma), True)
  120. (self.bin_ksize, self.bin_stride, self.pad) = get_sift_bin_ksize_stride_pad(patch_size, num_spatial_bins)
  121. nw = get_sift_pooling_kernel(ksize=self.bin_ksize).float()
  122. self.pk = nn.Conv2d(
  123. 1,
  124. 1,
  125. kernel_size=(nw.size(0), nw.size(1)),
  126. stride=(self.bin_stride, self.bin_stride),
  127. padding=(self.pad, self.pad),
  128. bias=False,
  129. )
  130. self.pk.weight.data.copy_(nw.reshape(1, 1, nw.size(0), nw.size(1)))
  131. def get_pooling_kernel(self) -> Tensor:
  132. return self.pk.weight.detach()
  133. def get_weighting_kernel(self) -> Tensor:
  134. return self.gk.detach()
  135. def forward(self, input: Tensor) -> Tensor:
  136. KORNIA_CHECK_SHAPE(input, ["B", "1", f"{self.patch_size}", f"{self.patch_size}"])
  137. B: int = input.shape[0]
  138. self.pk = self.pk.to(input.dtype).to(input.device)
  139. grads = spatial_gradient(input, "diff")
  140. # unpack the edges
  141. gx = grads[:, :, 0]
  142. gy = grads[:, :, 1]
  143. mag = torch.sqrt(gx * gx + gy * gy + self.eps)
  144. ori = torch.atan2(gy, gx + self.eps) + 2.0 * pi
  145. mag = mag * self.gk.expand_as(mag).type_as(mag).to(mag.device)
  146. o_big = float(self.num_ang_bins) * ori / (2.0 * pi)
  147. bo0_big_ = torch.floor(o_big)
  148. wo1_big_ = o_big - bo0_big_
  149. bo0_big = bo0_big_ % self.num_ang_bins
  150. bo1_big = (bo0_big + 1) % self.num_ang_bins
  151. wo0_big = (1.0 - wo1_big_) * mag
  152. wo1_big = wo1_big_ * mag
  153. ang_bins = concatenate(
  154. [
  155. self.pk((bo0_big == i).to(input.dtype) * wo0_big + (bo1_big == i).to(input.dtype) * wo1_big)
  156. for i in range(0, self.num_ang_bins)
  157. ],
  158. 1,
  159. )
  160. ang_bins = ang_bins.view(B, -1)
  161. ang_bins = normalize(ang_bins, p=2)
  162. ang_bins = torch.clamp(ang_bins, 0.0, float(self.clipval))
  163. ang_bins = normalize(ang_bins, p=2)
  164. if self.rootsift:
  165. ang_bins = torch.sqrt(normalize(ang_bins, p=1) + self.eps)
  166. return ang_bins
  167. def sift_describe(
  168. input: Tensor,
  169. patch_size: int = 41,
  170. num_ang_bins: int = 8,
  171. num_spatial_bins: int = 4,
  172. rootsift: bool = True,
  173. clipval: float = 0.2,
  174. ) -> Tensor:
  175. r"""Compute the sift descriptor.
  176. See
  177. :class: `~kornia.feature.SIFTDescriptor` for details.
  178. """
  179. return SIFTDescriptor(patch_size, num_ang_bins, num_spatial_bins, rootsift, clipval)(input)
  180. class DenseSIFTDescriptor(Module):
  181. """Module, which computes SIFT descriptor densely over the image.
  182. Args:
  183. num_ang_bins: Number of angular bins. (8 is default)
  184. num_spatial_bins: Number of spatial bins per descriptor (4 is default).
  185. You might want to set odd number and relevant padding to keep feature map size
  186. spatial_bin_size: Size of a spatial bin in pixels (4 is default)
  187. clipval: clipping value to reduce single-bin dominance
  188. rootsift: (bool) if True, RootSIFT (Arandjelović et. al, 2012) is computed
  189. stride: default 1
  190. padding: default 0
  191. Returns:
  192. Tensor: DenseSIFT descriptor of the image
  193. Shape:
  194. - Input: (B, 1, H, W)
  195. - Output: (B, num_ang_bins * num_spatial_bins ** 2, (H+padding)/stride, (W+padding)/stride)
  196. Examples::
  197. >>> input = torch.rand(2, 1, 200, 300)
  198. >>> SIFT = DenseSIFTDescriptor()
  199. >>> descs = SIFT(input) # 2x128x194x294
  200. """
  201. def __repr__(self) -> str:
  202. return (
  203. f"{self.__class__.__name__}("
  204. f"num_ang_bins={self.num_ang_bins}, "
  205. f"num_spatial_bins={self.num_spatial_bins}, "
  206. f"spatial_bin_size={self.spatial_bin_size}, "
  207. f"rootsift={self.rootsift}, "
  208. f"stride={self.stride}, "
  209. f"clipval={self.clipval})"
  210. )
  211. def __init__(
  212. self,
  213. num_ang_bins: int = 8,
  214. num_spatial_bins: int = 4,
  215. spatial_bin_size: int = 4,
  216. rootsift: bool = True,
  217. clipval: float = 0.2,
  218. stride: int = 1,
  219. padding: int = 1,
  220. ) -> None:
  221. super().__init__()
  222. self.eps = 1e-10
  223. self.num_ang_bins = num_ang_bins
  224. self.num_spatial_bins = num_spatial_bins
  225. self.spatial_bin_size = spatial_bin_size
  226. self.clipval = clipval
  227. self.rootsift = rootsift
  228. self.stride = stride
  229. self.pad = padding
  230. # Only allocate pooling kernels once during construction
  231. nw = get_sift_pooling_kernel(ksize=self.spatial_bin_size).float()
  232. self.register_buffer("_bin_pooling_kernel_weight", nw.reshape(1, 1, nw.size(0), nw.size(1)))
  233. bin_pooling_kernel = nn.Conv2d(
  234. 1,
  235. 1,
  236. kernel_size=(nw.size(0), nw.size(1)),
  237. stride=(1, 1),
  238. bias=False,
  239. padding=(nw.size(0) // 2, nw.size(1) // 2),
  240. )
  241. bin_pooling_kernel.weight.data.copy_(self._bin_pooling_kernel_weight)
  242. self.bin_pooling_kernel = bin_pooling_kernel
  243. Pw = _get_reshape_kernel(num_ang_bins, num_spatial_bins, num_spatial_bins).float()
  244. self.register_buffer("_poolingconv_weight", Pw)
  245. PoolingConv = nn.Conv2d(
  246. num_ang_bins,
  247. num_ang_bins * num_spatial_bins**2,
  248. kernel_size=(num_spatial_bins, num_spatial_bins),
  249. stride=(self.stride, self.stride),
  250. bias=False,
  251. padding=(self.pad, self.pad),
  252. )
  253. PoolingConv.weight.data.copy_(self._poolingconv_weight)
  254. self.PoolingConv = PoolingConv
  255. # Cache pooling kernel tensor for fast return in get_pooling_kernel
  256. self._pooling_kernel = self._bin_pooling_kernel_weight.detach()
  257. def get_pooling_kernel(self) -> Tensor:
  258. # Return the cached detached pooling kernel directly for optimal speed
  259. return self._pooling_kernel
  260. def forward(self, input: Tensor) -> Tensor:
  261. KORNIA_CHECK_SHAPE(input, ["B", "1", "H", "W"])
  262. _B, _CH, _W, _H = input.size()
  263. self.bin_pooling_kernel = self.bin_pooling_kernel.to(input.dtype).to(input.device)
  264. self.PoolingConv = self.PoolingConv.to(input.dtype).to(input.device)
  265. grads = spatial_gradient(input, "diff")
  266. # unpack the edges
  267. gx = grads[:, :, 0]
  268. gy = grads[:, :, 1]
  269. mag = torch.sqrt(gx * gx + gy * gy + self.eps)
  270. ori = torch.atan2(gy, gx + self.eps) + 2.0 * pi
  271. o_big = float(self.num_ang_bins) * ori / (2.0 * pi)
  272. bo0_big_ = torch.floor(o_big)
  273. wo1_big_ = o_big - bo0_big_
  274. bo0_big = bo0_big_ % self.num_ang_bins
  275. bo1_big = (bo0_big + 1) % self.num_ang_bins
  276. wo0_big = (1.0 - wo1_big_) * mag
  277. wo1_big = wo1_big_ * mag
  278. ang_bins = concatenate(
  279. [
  280. self.bin_pooling_kernel(
  281. (bo0_big == i).to(input.dtype) * wo0_big + (bo1_big == i).to(input.dtype) * wo1_big
  282. )
  283. for i in range(0, self.num_ang_bins)
  284. ],
  285. 1,
  286. )
  287. out_no_norm = self.PoolingConv(ang_bins)
  288. out = normalize(out_no_norm, dim=1, p=2).clamp_(0, float(self.clipval))
  289. out = normalize(out, dim=1, p=2)
  290. if self.rootsift:
  291. out = torch.sqrt(normalize(out, p=1) + self.eps)
  292. return out