keynet.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import List, Optional
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from typing_extensions import TypedDict
  22. from kornia.core import Module, Tensor, concatenate
  23. from kornia.filters import SpatialGradient
  24. from kornia.geometry.transform import pyrdown
  25. from .scale_space_detector import Detector_config, MultiResolutionDetector, get_default_detector_config
class KeyNet_conf(TypedDict):
    """Typed configuration dictionary consumed by :class:`KeyNet` and :class:`KeyNetDetector`."""

    # Multiplied by ``num_levels`` to size the input channels of KeyNet's last convolution.
    num_filters: int
    # Number of pyramid levels KeyNet.forward processes (the input plus ``num_levels - 1`` downscaled copies).
    num_levels: int
    # Kernel size of KeyNet's last convolution; its padding is derived as ``kernel_size // 2``.
    kernel_size: int
    # Forwarded by KeyNetDetector to MultiResolutionDetector as the detector configuration.
    Detector_conf: Detector_config
# Default configuration, used as the default value of ``keynet_conf`` in KeyNet and KeyNetDetector.
# The same dict object is shared by every caller that relies on the default, so treat it as read-only.
keynet_default_config: KeyNet_conf = {
    # Key.Net Model
    # 8 equals the default ``out_channels`` of ``_KeyNetConvBlock``.
    "num_filters": 8,
    "num_levels": 3,
    "kernel_size": 5,
    # Extraction Parameters
    "Detector_conf": get_default_detector_config(),
}
# Location of the checkpoint that KeyNet downloads through torch.hub when ``pretrained=True``.
KeyNet_URL = "https://github.com/axelBarroso/Key.Net-Pytorch/raw/main/model/weights/keynet_pytorch.pth"
  40. class _FeatureExtractor(Module):
  41. """Helper class for KeyNet.
  42. It loads both, the handcrafted and learnable blocks
  43. """
  44. def __init__(self) -> None:
  45. super().__init__()
  46. self.hc_block = _HandcraftedBlock()
  47. self.lb_block = _LearnableBlock()
  48. def forward(self, x: Tensor) -> Tensor:
  49. x_hc = self.hc_block(x)
  50. x_lb = self.lb_block(x_hc)
  51. return x_lb
  52. class _HandcraftedBlock(Module):
  53. """Helper class for KeyNet, it defines the handcrafted filters within the Key.Net handcrafted block."""
  54. def __init__(self) -> None:
  55. super().__init__()
  56. self.spatial_gradient = SpatialGradient("sobel", 1)
  57. def forward(self, x: Tensor) -> Tensor:
  58. sobel = self.spatial_gradient(x)
  59. dx, dy = sobel[:, :, 0, :, :], sobel[:, :, 1, :, :]
  60. sobel_dx = self.spatial_gradient(dx)
  61. dxx, dxy = sobel_dx[:, :, 0, :, :], sobel_dx[:, :, 1, :, :]
  62. sobel_dy = self.spatial_gradient(dy)
  63. dyy = sobel_dy[:, :, 1, :, :]
  64. hc_feats = concatenate([dx, dy, dx**2.0, dy**2.0, dx * dy, dxy, dxy**2.0, dxx, dyy, dxx * dyy], 1)
  65. return hc_feats
  66. class _LearnableBlock(nn.Sequential):
  67. """Helper class for KeyNet.
  68. It defines the learnable blocks within the Key.Net
  69. """
  70. def __init__(self, in_channels: int = 10) -> None:
  71. super().__init__()
  72. self.conv0 = _KeyNetConvBlock(in_channels)
  73. self.conv1 = _KeyNetConvBlock()
  74. self.conv2 = _KeyNetConvBlock()
  75. def forward(self, x: Tensor) -> Tensor:
  76. x = self.conv2(self.conv1(self.conv0(x)))
  77. return x
  78. def _KeyNetConvBlock(
  79. in_channels: int = 8,
  80. out_channels: int = 8,
  81. kernel_size: int = 5,
  82. stride: int = 1,
  83. padding: int = 2,
  84. dilation: int = 1,
  85. ) -> nn.Sequential:
  86. """Create KeyNet Conv Block.
  87. Default learnable convolutional block for KeyNet.
  88. """
  89. return nn.Sequential(
  90. nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation),
  91. nn.BatchNorm2d(out_channels),
  92. nn.ReLU(inplace=True),
  93. )
  94. class KeyNet(Module):
  95. """Key.Net model definition -- local feature detector (response function).
  96. This is based on the original code
  97. from paper "Key.Net: Keypoint Detection by Handcrafted and Learned CNN Filters". See :cite:`KeyNet2019` for
  98. more details.
  99. .. image:: _static/img/KeyNet.png
  100. Args:
  101. pretrained: Download and set pretrained weights to the model.
  102. keynet_conf: Dict with initialization parameters. Do not pass it, unless you know what you are doing`.
  103. Returns:
  104. KeyNet response score.
  105. Shape:
  106. - Input: :math:`(B, 1, H, W)`
  107. - Output: :math:`(B, 1, H, W)`
  108. """
  109. def __init__(self, pretrained: bool = False, keynet_conf: KeyNet_conf = keynet_default_config) -> None:
  110. super().__init__()
  111. num_filters = keynet_conf["num_filters"]
  112. self.num_levels = keynet_conf["num_levels"]
  113. kernel_size = keynet_conf["kernel_size"]
  114. padding = kernel_size // 2
  115. self.feature_extractor = _FeatureExtractor()
  116. self.last_conv = nn.Sequential(
  117. nn.Conv2d(
  118. in_channels=num_filters * self.num_levels, out_channels=1, kernel_size=kernel_size, padding=padding
  119. ),
  120. nn.ReLU(inplace=True),
  121. )
  122. # use torch.hub to load pretrained model
  123. if pretrained:
  124. pretrained_dict = torch.hub.load_state_dict_from_url(KeyNet_URL, map_location=torch.device("cpu"))
  125. self.load_state_dict(pretrained_dict["state_dict"], strict=True)
  126. self.eval()
  127. def forward(self, x: Tensor) -> Tensor:
  128. """X - input image."""
  129. shape_im = x.shape
  130. feats: List[Tensor] = [self.feature_extractor(x)]
  131. for _ in range(1, self.num_levels):
  132. x = pyrdown(x, factor=1.2)
  133. feats_i = self.feature_extractor(x)
  134. feats_i = F.interpolate(feats_i, size=(shape_im[2], shape_im[3]), mode="bilinear")
  135. feats.append(feats_i)
  136. scores = self.last_conv(concatenate(feats, 1))
  137. return scores
  138. class KeyNetDetector(MultiResolutionDetector):
  139. """Multi-scale feature detector based on KeyNet.
  140. This is based on the original code from paper
  141. "Key.Net: Keypoint Detection by Handcrafted and Learned CNN Filters".
  142. See :cite:`KeyNet2019` for more details.
  143. .. image:: _static/img/keynet.jpg
  144. Args:
  145. pretrained: Download and set pretrained weights to the model.
  146. num_features: Number of features to detect.
  147. keynet_conf: Dict with initialization parameters. Do not pass it, unless you know what you are doing`.
  148. ori_module: for local feature orientation estimation. Default: :class:`~kornia.feature.PassLAF`,
  149. which does nothing. See :class:`~kornia.feature.LAFOrienter` for details.
  150. aff_module: for local feature affine shape estimation. Default: :class:`~kornia.feature.PassLAF`,
  151. which does nothing. See :class:`~kornia.feature.LAFAffineShapeEstimator` for details.
  152. """
  153. def __init__(
  154. self,
  155. pretrained: bool = False,
  156. num_features: int = 2048,
  157. keynet_conf: KeyNet_conf = keynet_default_config,
  158. ori_module: Optional[Module] = None,
  159. aff_module: Optional[Module] = None,
  160. ) -> None:
  161. model = KeyNet(pretrained, keynet_conf)
  162. super().__init__(model, num_features, keynet_conf["Detector_conf"], ori_module, aff_module)