# hynet.py
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Dict
  18. import torch
  19. from torch import nn
  20. from kornia.core import Module, Parameter, Tensor, tensor, zeros
  21. urls: Dict[str, str] = {}
  22. urls["liberty"] = "https://github.com/ducha-aiki/Key.Net-Pytorch/raw/main/model/HyNet/weights/HyNet_LIB.pth" # pylint: disable
  23. urls["notredame"] = "https://github.com/ducha-aiki/Key.Net-Pytorch/raw/main/model/HyNet/weights/HyNet_ND.pth" # pylint: disable
  24. urls["yosemite"] = "https://github.com/ducha-aiki/Key.Net-Pytorch/raw/main/model/HyNet/weights/HyNet_YOS.pth" # pylint: disable
  25. class FilterResponseNorm2d(Module):
  26. r"""Feature Response Normalization layer from 'Filter Response Normalization Layer: Eliminating Batch Dependence
  27. in the Training of Deep Neural Networks', see :cite:`FRN2019` for more details.
  28. .. math::
  29. y = \gamma \times \frac{x}{\sqrt{\mathrm{E}[x^2]} + |\epsilon|} + \beta
  30. Args:
  31. num_features: number of channels
  32. eps: normalization constant
  33. is_bias: use bias
  34. is_scale: use scale
  35. drop_rate: dropout rate,
  36. is_eps_leanable: if eps is learnable
  37. Returns:
  38. torch.Tensor: Normalized features
  39. Shape:
  40. - Input: :math:`(B, \text{num_features}, H, W)`
  41. - Output: :math:`(B, \text{num_features}, H, W)`
  42. """ # noqa: D205
  43. def __init__(
  44. self,
  45. num_features: int,
  46. eps: float = 1e-6,
  47. is_bias: bool = True,
  48. is_scale: bool = True,
  49. is_eps_leanable: bool = False,
  50. ) -> None:
  51. super().__init__()
  52. self.num_features = num_features
  53. self.init_eps = eps
  54. self.is_eps_leanable = is_eps_leanable
  55. self.is_bias = is_bias
  56. self.is_scale = is_scale
  57. self.weight = Parameter(torch.ones(1, num_features, 1, 1), requires_grad=True)
  58. self.bias = Parameter(zeros(1, num_features, 1, 1), requires_grad=True)
  59. if is_eps_leanable:
  60. self.eps = Parameter(tensor(1), requires_grad=True)
  61. else:
  62. self.register_buffer("eps", tensor([eps]))
  63. self.reset_parameters()
  64. def reset_parameters(self) -> None:
  65. nn.init.ones_(self.weight)
  66. nn.init.zeros_(self.bias)
  67. if self.is_eps_leanable:
  68. nn.init.constant_(self.eps, self.init_eps)
  69. def extra_repr(self) -> str:
  70. return "num_features={num_features}, eps={init_eps}".format(**self.__dict__)
  71. def forward(self, x: Tensor) -> Tensor:
  72. # Compute the mean norm of activations per channel.
  73. nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
  74. # Perform FRN.
  75. x = x * torch.rsqrt(nu2 + self.eps.abs())
  76. # Scale and Bias
  77. if self.is_scale:
  78. x = self.weight * x
  79. if self.is_bias:
  80. x = x + self.bias
  81. return x
  82. class TLU(Module):
  83. r"""TLU layer from 'Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep
  84. Neural Networks, see :cite:`FRN2019` for more details. :math:`{\tau}` is learnable per channel.
  85. .. math::
  86. y = \max(x, {\tau})
  87. Args:
  88. num_features: number of channels
  89. Returns:
  90. torch.Tensor
  91. Shape:
  92. - Input: :math:`(B, \text{num_features}, H, W)`
  93. - Output: :math:`(B, \text{num_features}, H, W)`
  94. """ # noqa:D205
  95. def __init__(self, num_features: int) -> None:
  96. """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau."""
  97. super().__init__()
  98. self.num_features = num_features
  99. self.tau = Parameter(-torch.ones(1, num_features, 1, 1), requires_grad=True)
  100. self.reset_parameters()
  101. def reset_parameters(self) -> None:
  102. # nn.init.zeros_(self.tau)
  103. nn.init.constant_(self.tau, -1)
  104. def extra_repr(self) -> str:
  105. return "num_features={num_features}".format(**self.__dict__)
  106. def forward(self, x: Tensor) -> Tensor:
  107. return torch.max(x, self.tau)
  108. class HyNet(Module):
  109. r"""Module, which computes HyNet descriptors of given grayscale patches of 32x32.
  110. This is based on the original code from paper
  111. "HyNet: Learning Local Descriptor with Hybrid Similarity Measure and Triplet Loss".
  112. See :cite:`hynet2020` for more details.
  113. Args:
  114. pretrained: Download and set pretrained weights to the model.
  115. is_bias: use bias in TLU layers
  116. is_bias_FRN: use bias in FRN layers
  117. dim_desc: descriptor dimensionality,
  118. drop_rate: dropout rate,
  119. eps_l2_norm: to avoid div by zero
  120. Returns:
  121. HyNet descriptor of the patches.
  122. Shape:
  123. - Input: :math:`(B, 1, 32, 32)`
  124. - Output: :math:`(B, 128)`
  125. Examples:
  126. >>> input = torch.rand(16, 1, 32, 32)
  127. >>> hynet = HyNet()
  128. >>> descs = hynet(input) # 16x128
  129. """
  130. patch_size = 32
  131. def __init__(
  132. self,
  133. pretrained: bool = False,
  134. is_bias: bool = True,
  135. is_bias_FRN: bool = True,
  136. dim_desc: int = 128,
  137. drop_rate: float = 0.3,
  138. eps_l2_norm: float = 1e-10,
  139. ) -> None:
  140. super().__init__()
  141. self.eps_l2_norm = eps_l2_norm
  142. self.dim_desc = dim_desc
  143. self.drop_rate = drop_rate
  144. self.layer1 = nn.Sequential(
  145. FilterResponseNorm2d(1, is_bias=is_bias_FRN),
  146. TLU(1),
  147. nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=is_bias),
  148. FilterResponseNorm2d(32, is_bias=is_bias_FRN),
  149. TLU(32),
  150. )
  151. self.layer2 = nn.Sequential(
  152. nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=is_bias),
  153. FilterResponseNorm2d(32, is_bias=is_bias_FRN),
  154. TLU(32),
  155. )
  156. self.layer3 = nn.Sequential(
  157. nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=is_bias),
  158. FilterResponseNorm2d(64, is_bias=is_bias_FRN),
  159. TLU(64),
  160. )
  161. self.layer4 = nn.Sequential(
  162. nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=is_bias),
  163. FilterResponseNorm2d(64, is_bias=is_bias_FRN),
  164. TLU(64),
  165. )
  166. self.layer5 = nn.Sequential(
  167. nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=is_bias),
  168. FilterResponseNorm2d(128, is_bias=is_bias_FRN),
  169. TLU(128),
  170. )
  171. self.layer6 = nn.Sequential(
  172. nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=is_bias),
  173. FilterResponseNorm2d(128, is_bias=is_bias_FRN),
  174. TLU(128),
  175. )
  176. self.layer7 = nn.Sequential(
  177. nn.Dropout(self.drop_rate),
  178. nn.Conv2d(128, self.dim_desc, kernel_size=8, bias=False),
  179. nn.BatchNorm2d(self.dim_desc, affine=False),
  180. )
  181. self.desc_norm = nn.LocalResponseNorm(2 * self.dim_desc, 2.0 * self.dim_desc, 0.5, 0.0)
  182. # use torch.hub to load pretrained model
  183. if pretrained:
  184. pretrained_dict = torch.hub.load_state_dict_from_url(urls["liberty"], map_location=torch.device("cpu"))
  185. self.load_state_dict(pretrained_dict, strict=True)
  186. self.eval()
  187. def forward(self, x: Tensor) -> Tensor:
  188. x = self.layer1(x)
  189. x = self.layer2(x)
  190. x = self.layer3(x)
  191. x = self.layer4(x)
  192. x = self.layer5(x)
  193. x = self.layer6(x)
  194. x = self.layer7(x)
  195. x = self.desc_norm(x + self.eps_l2_norm)
  196. x = x.view(x.size(0), -1)
  197. return x