orientation.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Dict, Optional
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from kornia.constants import pi
  22. from kornia.core.check import KORNIA_CHECK_LAF, KORNIA_CHECK_SHAPE
  23. from kornia.filters import SpatialGradient, get_gaussian_discrete_kernel1d, get_gaussian_kernel2d
  24. from kornia.geometry import rad2deg
  25. from .laf import extract_patches_from_pyramid, get_laf_orientation, set_laf_orientation
  26. urls: Dict[str, str] = {}
  27. urls["orinet"] = "https://github.com/ducha-aiki/affnet/raw/master/pretrained/OriNet.pth"
  28. class PassLAF(nn.Module):
  29. """Dummy module to use instead of local feature orientation or affine shape estimator."""
  30. def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
  31. """Run forward.
  32. Args:
  33. laf: :math:`(B, N, 2, 3)`
  34. img: :math:`(B, 1, H, W)`
  35. Returns:
  36. LAF, unchanged :math:`(B, N, 2, 3)`
  37. """
  38. return laf
  39. class PatchDominantGradientOrientation(nn.Module):
  40. """Module, which estimates the dominant gradient orientation of the given patches, in radians.
  41. Zero angle points towards right.
  42. Args:
  43. patch_size: size of the (square) input patch.
  44. num_angular_bins: number of histogram bins.
  45. eps: for safe division, and arctan.
  46. """
  47. def __init__(self, patch_size: int = 32, num_angular_bins: int = 36, eps: float = 1e-8) -> None:
  48. super().__init__()
  49. self.patch_size = patch_size
  50. self.num_ang_bins = num_angular_bins
  51. self.gradient = SpatialGradient("sobel", 1)
  52. self.eps = eps
  53. self.angular_smooth = nn.Conv1d(1, 1, kernel_size=5, padding=2, bias=False, padding_mode="circular")
  54. with torch.no_grad():
  55. self.angular_smooth.weight[:] = get_gaussian_discrete_kernel1d(5, 1.6)
  56. sigma: float = float(self.patch_size) / 6.0
  57. self.weighting = get_gaussian_kernel2d((self.patch_size, self.patch_size), (sigma, sigma), True)
  58. def __repr__(self) -> str:
  59. return (
  60. f"{self.__class__.__name__}(patch_size={self.patch_size}, num_ang_bins={self.num_ang_bins}, eps={self.eps})"
  61. )
  62. def forward(self, patch: torch.Tensor) -> torch.Tensor:
  63. """Run forward.
  64. Args:
  65. patch: :math:`(B, 1, H, W)`
  66. Returns:
  67. angle in radians: :math:`(B)`
  68. """
  69. KORNIA_CHECK_SHAPE(patch, ["B", "1", "H", "W"])
  70. _, CH, W, H = patch.size()
  71. if (W != self.patch_size) or (H != self.patch_size) or (CH != 1):
  72. raise TypeError(
  73. f"input shape should be must be [Bx1x{self.patch_size}x{self.patch_size}]. Got {patch.size()}"
  74. )
  75. self.weighting = self.weighting.to(patch.dtype).to(patch.device)
  76. self.angular_smooth = self.angular_smooth.to(patch.dtype).to(patch.device)
  77. grads: torch.Tensor = self.gradient(patch)
  78. # unpack the edges
  79. gx: torch.Tensor = grads[:, :, 0]
  80. gy: torch.Tensor = grads[:, :, 1]
  81. mag: torch.Tensor = torch.sqrt(gx * gx + gy * gy + self.eps) * self.weighting
  82. ori: torch.Tensor = torch.atan2(gy, gx + self.eps) + 2.0 * pi
  83. o_big = float(self.num_ang_bins) * (ori + 1.0 * pi) / (2.0 * pi)
  84. bo0_big = torch.floor(o_big)
  85. wo1_big = o_big - bo0_big
  86. bo0_big = bo0_big % self.num_ang_bins
  87. bo1_big = (bo0_big + 1) % self.num_ang_bins
  88. wo0_big = (1.0 - wo1_big) * mag
  89. wo1_big = wo1_big * mag
  90. ang_bins_list = []
  91. for i in range(0, self.num_ang_bins):
  92. ang_bins_i = F.adaptive_avg_pool2d(
  93. (bo0_big == i).to(patch.dtype) * wo0_big + (bo1_big == i).to(patch.dtype) * wo1_big, (1, 1)
  94. )
  95. ang_bins_list.append(ang_bins_i)
  96. ang_bins = torch.cat(ang_bins_list, 1).view(-1, 1, self.num_ang_bins)
  97. ang_bins = self.angular_smooth(ang_bins).view(-1, self.num_ang_bins)
  98. values, indices = ang_bins.max(1)
  99. indices_left = (self.num_ang_bins + indices - 1) % self.num_ang_bins
  100. indices_right = (indices + 1) % self.num_ang_bins
  101. left = torch.gather(ang_bins, 1, indices_left.reshape(-1, 1)).reshape(-1)
  102. center = values
  103. right = torch.gather(ang_bins, 1, indices_right.reshape(-1, 1)).reshape(-1)
  104. c_subpix = 0.5 * (left - right) / (left + right - 2.0 * center)
  105. angle = -((2.0 * pi * (indices.to(patch.dtype) + c_subpix) / float(self.num_ang_bins)) - pi)
  106. return angle
  107. class OriNet(nn.Module):
  108. """Network, which estimates the canonical orientation of the given 32x32 patches, in radians.
  109. Zero angle points towards right. This is based on the original code from paper
  110. "Repeatability Is Not Enough: Learning Discriminative Affine Regions via Discriminability"".
  111. See :cite:`AffNet2018` for more details.
  112. Args:
  113. pretrained: Download and set pretrained weights to the model.
  114. eps: to avoid division by zero in atan2.
  115. Returns:
  116. Angle in radians.
  117. Shape:
  118. - Input: (B, 1, 32, 32)
  119. - Output: (B)
  120. Examples:
  121. >>> input = torch.rand(16, 1, 32, 32)
  122. >>> orinet = OriNet()
  123. >>> angle = orinet(input) # 16
  124. """
  125. def __init__(self, pretrained: bool = False, eps: float = 1e-8) -> None:
  126. super().__init__()
  127. self.features = nn.Sequential(
  128. nn.Conv2d(1, 16, kernel_size=3, padding=1, bias=False),
  129. nn.BatchNorm2d(16, affine=False),
  130. nn.ReLU(),
  131. nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
  132. nn.BatchNorm2d(16, affine=False),
  133. nn.ReLU(),
  134. nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
  135. nn.BatchNorm2d(32, affine=False),
  136. nn.ReLU(),
  137. nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
  138. nn.BatchNorm2d(32, affine=False),
  139. nn.ReLU(),
  140. nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
  141. nn.BatchNorm2d(64, affine=False),
  142. nn.ReLU(),
  143. nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
  144. nn.BatchNorm2d(64, affine=False),
  145. nn.ReLU(),
  146. nn.Dropout(0.25),
  147. nn.Conv2d(64, 2, kernel_size=8, stride=1, padding=1, bias=True),
  148. nn.Tanh(),
  149. nn.AdaptiveAvgPool2d(1),
  150. )
  151. self.eps = eps
  152. # use torch.hub to load pretrained model
  153. if pretrained:
  154. pretrained_dict = torch.hub.load_state_dict_from_url(urls["orinet"], map_location=torch.device("cpu"))
  155. self.load_state_dict(pretrained_dict["state_dict"], strict=False)
  156. self.eval()
  157. @staticmethod
  158. def _normalize_input(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
  159. """Utility function that normalizes the input by batch."""
  160. sp, mp = torch.std_mean(x, dim=(-3, -2, -1), keepdim=True)
  161. # WARNING: we need to .detach() input, otherwise the gradients produced by
  162. # the patches extractor with F.grid_sample are very noisy, making the detector
  163. # training totally unstable.
  164. return (x - mp.detach()) / (sp.detach() + eps)
  165. def forward(self, patch: torch.Tensor) -> torch.Tensor:
  166. """Run forward.
  167. Args:
  168. patch: :math:`(B, 1, H, W)`
  169. Returns:
  170. angle in radians: :math:`(B)`
  171. """
  172. xy = self.features(self._normalize_input(patch)).view(-1, 2)
  173. angle = torch.atan2(xy[:, 0] + 1e-8, xy[:, 1] + self.eps)
  174. return angle
  175. class LAFOrienter(nn.Module):
  176. """Module, which extracts patches using input images and local affine frames (LAFs).
  177. Then runs :class:`~kornia.feature.PatchDominantGradientOrientation` or
  178. :class:`~kornia.feature.OriNet` on patches and then rotates the LAFs by the estimated angles
  179. Args:
  180. patch_size:
  181. num_angular_bins:
  182. angle_detector: Patch orientation estimator, e.g. :class:`~kornia.feature.PatchDominantGradientOrientation`
  183. or OriNet.
  184. """ # pylint: disable
  185. def __init__(
  186. self, patch_size: int = 32, num_angular_bins: int = 36, angle_detector: Optional[nn.Module] = None
  187. ) -> None:
  188. super().__init__()
  189. self.patch_size = patch_size
  190. self.num_ang_bins = num_angular_bins
  191. self.angle_detector: nn.Module
  192. if angle_detector is None:
  193. self.angle_detector = PatchDominantGradientOrientation(self.patch_size, self.num_ang_bins)
  194. else:
  195. self.angle_detector = angle_detector
  196. def __repr__(self) -> str:
  197. return f"{self.__class__.__name__}(patch_size={self.patch_size}, angle_detector={self.angle_detector})"
  198. def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
  199. """Run forward.
  200. Args:
  201. laf: :math:`(B, N, 2, 3)`
  202. img: :math:`(B, 1, H, W)`
  203. Returns:
  204. LAF_out: :math:`(B, N, 2, 3)`
  205. """
  206. KORNIA_CHECK_LAF(laf)
  207. KORNIA_CHECK_SHAPE(img, ["B", "C", "H", "W"])
  208. if laf.size(0) != img.size(0):
  209. raise ValueError(f"Batch size of laf and img should be the same. Got {img.size(0)}, {laf.size(0)}")
  210. B, N = laf.shape[:2]
  211. patches: torch.Tensor = extract_patches_from_pyramid(img, laf, self.patch_size).view(
  212. -1, 1, self.patch_size, self.patch_size
  213. )
  214. angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
  215. prev_angle = get_laf_orientation(laf).view_as(angles_radians)
  216. laf_out: torch.Tensor = set_laf_orientation(laf, rad2deg(angles_radians) + prev_angle)
  217. return laf_out