affine_shape.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import math
  18. import warnings
  19. from typing import Dict, Optional
  20. import torch
  21. from torch import nn
  22. from kornia.core.check import KORNIA_CHECK_LAF, KORNIA_CHECK_SHAPE
  23. from kornia.filters.kernels import get_gaussian_kernel2d
  24. from kornia.filters.sobel import SpatialGradient
  25. from .laf import (
  26. ellipse_to_laf,
  27. extract_patches_from_pyramid,
  28. get_laf_orientation,
  29. get_laf_scale,
  30. make_upright,
  31. scale_laf,
  32. set_laf_orientation,
  33. )
  34. urls: Dict[str, str] = {}
  35. urls["affnet"] = "https://github.com/ducha-aiki/affnet/raw/master/pretrained/AffNet.pth"
  36. class PatchAffineShapeEstimator(nn.Module):
  37. r"""Module, which estimates the second moment matrix of the patch gradients.
  38. The method determines the affine shape of the local feature as in :cite:`baumberg2000`.
  39. Args:
  40. patch_size: the input image patch size.
  41. eps: for safe division.
  42. """
  43. def __init__(self, patch_size: int = 19, eps: float = 1e-10) -> None:
  44. super().__init__()
  45. self.patch_size: int = patch_size
  46. self.gradient: nn.Module = SpatialGradient("sobel", 1)
  47. self.eps: float = eps
  48. sigma: float = float(self.patch_size) / math.sqrt(2.0)
  49. self.weighting: torch.Tensor = get_gaussian_kernel2d((self.patch_size, self.patch_size), (sigma, sigma), True)
  50. def __repr__(self) -> str:
  51. return f"{self.__class__.__name__}(patch_size={self.patch_size}, eps={self.eps})"
  52. def forward(self, patch: torch.Tensor) -> torch.Tensor:
  53. """Run forward.
  54. Args:
  55. patch: :math:`(B, 1, H, W)`
  56. Returns:
  57. torch.Tensor: ellipse_shape :math:`(B, 1, 3)`
  58. """
  59. KORNIA_CHECK_SHAPE(patch, ["B", "1", "H", "W"])
  60. self.weighting = self.weighting.to(patch.dtype).to(patch.device)
  61. grads: torch.Tensor = self.gradient(patch) * self.weighting
  62. # unpack the edges
  63. gx: torch.Tensor = grads[:, :, 0]
  64. gy: torch.Tensor = grads[:, :, 1]
  65. # abc == 1st axis, mixture, 2nd axis. Ellipse_shape is a 2nd moment matrix.
  66. ellipse_shape = torch.cat(
  67. [
  68. gx.pow(2).mean(dim=2).mean(dim=2, keepdim=True),
  69. (gx * gy).mean(dim=2).mean(dim=2, keepdim=True),
  70. gy.pow(2).mean(dim=2).mean(dim=2, keepdim=True),
  71. ],
  72. dim=2,
  73. )
  74. # Now lets detect degenerate cases: when 2 or 3 elements are close to zero (e.g. if patch is completely black
  75. bad_mask = ((ellipse_shape < self.eps).float().sum(dim=2, keepdim=True) >= 2).to(ellipse_shape.dtype)
  76. # We will replace degenerate shape with circular shapes.
  77. circular_shape = torch.tensor([1.0, 0.0, 1.0]).to(ellipse_shape.device).to(ellipse_shape.dtype).view(1, 1, 3)
  78. ellipse_shape = ellipse_shape * (1.0 - bad_mask) + circular_shape * bad_mask
  79. # normalization
  80. ellipse_shape = ellipse_shape / ellipse_shape.max(dim=2, keepdim=True)[0]
  81. return ellipse_shape
  82. class LAFAffineShapeEstimator(nn.Module):
  83. """Module, which extracts patches using input images and local affine frames (LAFs).
  84. Then runs :class:`~kornia.feature.PatchAffineShapeEstimator` on patches to estimate LAFs shape.
  85. Then original LAF shape is replaced with estimated one. The original LAF orientation is not preserved,
  86. so it is recommended to first run LAFAffineShapeEstimator and then LAFOrienter,
  87. Args:
  88. patch_size: the input image patch size.
  89. affine_shape_detector: Patch affine shape estimator, :class:`~kornia.feature.PatchAffineShapeEstimator`.
  90. preserve_orientation: if True, the original orientation is preserved.
  91. """ # pylint: disable
  92. def __init__(
  93. self, patch_size: int = 32, affine_shape_detector: Optional[nn.Module] = None, preserve_orientation: bool = True
  94. ) -> None:
  95. super().__init__()
  96. self.patch_size = patch_size
  97. self.affine_shape_detector = affine_shape_detector or PatchAffineShapeEstimator(self.patch_size)
  98. self.preserve_orientation = preserve_orientation
  99. if preserve_orientation:
  100. warnings.warn(
  101. "`LAFAffineShapeEstimator` default behaviour is changed "
  102. "and now it does preserve original LAF orientation. "
  103. "Make sure your code accounts for this.",
  104. DeprecationWarning,
  105. stacklevel=2,
  106. )
  107. def __repr__(self) -> str:
  108. return (
  109. f"{self.__class__.__name__}"
  110. f"(patch_size={self.patch_size}, "
  111. f"affine_shape_detector={self.affine_shape_detector}, "
  112. f"preserve_orientation={self.preserve_orientation})"
  113. )
  114. def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
  115. """Run forward.
  116. Args:
  117. laf: :math:`(B, N, 2, 3)`
  118. img: :math:`(B, 1, H, W)`
  119. Returns:
  120. LAF_out: :math:`(B, N, 2, 3)`
  121. """
  122. KORNIA_CHECK_LAF(laf)
  123. KORNIA_CHECK_SHAPE(img, ["B", "1", "H", "W"])
  124. B, N = laf.shape[:2]
  125. PS: int = self.patch_size
  126. patches: torch.Tensor = extract_patches_from_pyramid(img, make_upright(laf), PS, True).view(-1, 1, PS, PS)
  127. ellipse_shape: torch.Tensor = self.affine_shape_detector(patches)
  128. ellipses = torch.cat([laf.view(-1, 2, 3)[..., 2].unsqueeze(1), ellipse_shape], dim=2).view(B, N, 5)
  129. scale_orig = get_laf_scale(laf)
  130. if self.preserve_orientation:
  131. ori_orig = get_laf_orientation(laf)
  132. laf_out = ellipse_to_laf(ellipses)
  133. ellipse_scale = get_laf_scale(laf_out)
  134. laf_out = scale_laf(laf_out, scale_orig / ellipse_scale)
  135. if self.preserve_orientation:
  136. laf_out = set_laf_orientation(laf_out, ori_orig)
  137. return laf_out
  138. class LAFAffNetShapeEstimator(nn.Module):
  139. """Module, which extracts patches using input images and local affine frames (LAFs).
  140. Then runs AffNet on patches to estimate LAFs shape. This is based on the original code from paper
  141. "Repeatability Is Not Enough: Learning Discriminative Affine Regions via Discriminability"".
  142. See :cite:`AffNet2018` for more details.
  143. Then original LAF shape is replaced with estimated one. The original LAF orientation is not preserved,
  144. so it is recommended to first run LAFAffineShapeEstimator and then LAFOrienter.
  145. Args:
  146. pretrained: Download and set pretrained weights to the model.
  147. """
  148. def __init__(self, pretrained: bool = False, preserve_orientation: bool = True) -> None:
  149. super().__init__()
  150. self.features = nn.Sequential(
  151. nn.Conv2d(1, 16, kernel_size=3, padding=1, bias=False),
  152. nn.BatchNorm2d(16, affine=False),
  153. nn.ReLU(),
  154. nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
  155. nn.BatchNorm2d(16, affine=False),
  156. nn.ReLU(),
  157. nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
  158. nn.BatchNorm2d(32, affine=False),
  159. nn.ReLU(),
  160. nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
  161. nn.BatchNorm2d(32, affine=False),
  162. nn.ReLU(),
  163. nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
  164. nn.BatchNorm2d(64, affine=False),
  165. nn.ReLU(),
  166. nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
  167. nn.BatchNorm2d(64, affine=False),
  168. nn.ReLU(),
  169. nn.Dropout(0.25),
  170. nn.Conv2d(64, 3, kernel_size=8, stride=1, padding=0, bias=True),
  171. nn.Tanh(),
  172. nn.AdaptiveAvgPool2d(1),
  173. )
  174. self.patch_size = 32
  175. # use torch.hub to load pretrained model
  176. if pretrained:
  177. pretrained_dict = torch.hub.load_state_dict_from_url(urls["affnet"], map_location=torch.device("cpu"))
  178. self.load_state_dict(pretrained_dict["state_dict"], strict=False)
  179. self.preserve_orientation = preserve_orientation
  180. if preserve_orientation:
  181. warnings.warn(
  182. "`LAFAffNetShapeEstimator` default behaviour is changed "
  183. "and now it does preserve original LAF orientation. "
  184. "Make sure your code accounts for this.",
  185. DeprecationWarning,
  186. stacklevel=2,
  187. )
  188. self.eval()
  189. @staticmethod
  190. def _normalize_input(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
  191. """Normalize the input by batch."""
  192. sp, mp = torch.std_mean(x, dim=(-3, -2, -1), keepdim=True)
  193. # WARNING: we need to .detach() input, otherwise the gradients produced by
  194. # the patches extractor with F.grid_sample are very noisy, making the detector
  195. # training totally unstable.
  196. return (x - mp.detach()) / (sp.detach() + eps)
  197. def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
  198. """Run forward.
  199. Args:
  200. laf: :math:`(B, N, 2, 3)`
  201. img: :math:`(B, 1, H, W)`
  202. Returns:
  203. LAF_out: :math:`(B, N, 2, 3)`
  204. """
  205. KORNIA_CHECK_LAF(laf)
  206. KORNIA_CHECK_SHAPE(img, ["B", "1", "H", "W"])
  207. B, N = laf.shape[:2]
  208. PS: int = self.patch_size
  209. patches: torch.Tensor = extract_patches_from_pyramid(img, make_upright(laf), PS, True).view(-1, 1, PS, PS)
  210. xy = self.features(self._normalize_input(patches)).view(-1, 3)
  211. a1 = torch.cat([1.0 + xy[:, 0].reshape(-1, 1, 1), 0 * xy[:, 0].reshape(-1, 1, 1)], dim=2)
  212. a2 = torch.cat([xy[:, 1].reshape(-1, 1, 1), 1.0 + xy[:, 2].reshape(-1, 1, 1)], dim=2)
  213. new_laf_no_center = torch.cat([a1, a2], dim=1).reshape(B, N, 2, 2)
  214. new_laf = torch.cat([new_laf_no_center, laf[:, :, :, 2:3]], dim=3)
  215. scale_orig = get_laf_scale(laf)
  216. if self.preserve_orientation:
  217. ori_orig = get_laf_orientation(laf)
  218. ellipse_scale = get_laf_scale(new_laf)
  219. laf_out = scale_laf(make_upright(new_laf), scale_orig / ellipse_scale)
  220. if self.preserve_orientation:
  221. laf_out = set_laf_orientation(laf_out, ori_orig)
  222. return laf_out