| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Dict
- import torch
- from torch import nn
- from kornia.core import Module, Parameter, Tensor, tensor, zeros
# Download locations of the pretrained HyNet checkpoints, keyed by training-set name.
urls: Dict[str, str] = {
    "liberty": "https://github.com/ducha-aiki/Key.Net-Pytorch/raw/main/model/HyNet/weights/HyNet_LIB.pth",
    "notredame": "https://github.com/ducha-aiki/Key.Net-Pytorch/raw/main/model/HyNet/weights/HyNet_ND.pth",
    "yosemite": "https://github.com/ducha-aiki/Key.Net-Pytorch/raw/main/model/HyNet/weights/HyNet_YOS.pth",
}
class FilterResponseNorm2d(Module):
    r"""Filter Response Normalization (FRN) layer.

    From 'Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural
    Networks', see :cite:`FRN2019` for more details. Every channel is divided by the root of its mean squared
    activation, computed over the spatial dimensions only, then optionally scaled and shifted.

    .. math::
        y = \gamma \times \frac{x}{\sqrt{\mathrm{E}[x^2] + |\epsilon|}} + \beta

    Args:
        num_features: number of channels.
        eps: normalization constant.
        is_bias: add the per-channel bias :math:`\beta`.
        is_scale: multiply by the per-channel scale :math:`\gamma`.
        is_eps_leanable: if True, ``eps`` is a learnable parameter initialised to ``eps``;
            otherwise it is stored as a (non-trainable) buffer.

    Returns:
        torch.Tensor: Normalized features

    Shape:
        - Input: :math:`(B, \text{num_features}, H, W)`
        - Output: :math:`(B, \text{num_features}, H, W)`
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-6,
        is_bias: bool = True,
        is_scale: bool = True,
        is_eps_leanable: bool = False,
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable
        self.is_bias = is_bias
        self.is_scale = is_scale

        # weight and bias are registered regardless of is_scale / is_bias, so the
        # state_dict keys do not depend on those flags.
        self.weight = Parameter(torch.ones(1, num_features, 1, 1), requires_grad=True)
        self.bias = Parameter(zeros(1, num_features, 1, 1), requires_grad=True)
        if is_eps_leanable:
            # Must be a floating point tensor: an integer tensor cannot require gradients.
            # Shape (1,) mirrors the buffer used in the non-learnable case.
            self.eps = Parameter(tensor([float(eps)]), requires_grad=True)
        else:
            self.register_buffer("eps", tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset scale to one, bias to zero and, when learnable, eps to its initial value."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self) -> str:
        """Return the summary shown inside the module's ``repr``."""
        return f"num_features={self.num_features}, eps={self.init_eps}"

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` per channel, then apply the optional scale and bias."""
        # Mean squared activation per sample and channel, over the spatial dims (2, 3).
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)
        # abs() keeps the stabilizer non-negative even if a learnable eps drifts below zero.
        x = x * torch.rsqrt(nu2 + self.eps.abs())
        if self.is_scale:
            x = self.weight * x
        if self.is_bias:
            x = x + self.bias
        return x
class TLU(Module):
    r"""TLU activation layer with a learnable per-channel threshold :math:`{\tau}`.

    From 'Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural
    Networks', see :cite:`FRN2019` for more details.

    .. math::
        y = \max(x, {\tau})

    Args:
        num_features: number of channels

    Returns:
        torch.Tensor

    Shape:
        - Input: :math:`(B, \text{num_features}, H, W)`
        - Output: :math:`(B, \text{num_features}, H, W)`
    """

    def __init__(self, num_features: int) -> None:
        """Create the threshold parameter; note max(y, tau) = ReLU(y - tau) + tau."""
        super().__init__()
        self.num_features = num_features
        threshold = torch.full((1, num_features, 1, 1), -1.0)
        self.tau = Parameter(threshold, requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set every channel threshold back to -1."""
        nn.init.constant_(self.tau, -1)

    def extra_repr(self) -> str:
        """Return the summary shown inside the module's ``repr``."""
        return f"num_features={self.num_features}"

    def forward(self, x: Tensor) -> Tensor:
        """Clamp ``x`` from below by the per-channel threshold (broadcast over B, H, W)."""
        return torch.max(x, self.tau)
class HyNet(Module):
    r"""Compute HyNet descriptors for grayscale patches of 32x32.

    This is based on the original code from paper
    "HyNet: Learning Local Descriptor with Hybrid Similarity Measure and Triplet Loss".
    See :cite:`hynet2020` for more details.

    Args:
        pretrained: Download the weights listed under ``urls["liberty"]`` and load them into the model.
        is_bias: use bias in the 3x3 convolution layers.
        is_bias_FRN: use bias in FRN layers.
        dim_desc: descriptor dimensionality.
        drop_rate: dropout rate.
        eps_l2_norm: constant added to the features before the final normalization, to avoid div by zero.

    Returns:
        HyNet descriptor of the patches.

    Shape:
        - Input: :math:`(B, 1, 32, 32)`
        - Output: :math:`(B, \text{dim_desc})`, i.e. :math:`(B, 128)` with the default settings

    Examples:
        >>> input = torch.rand(16, 1, 32, 32)
        >>> hynet = HyNet()
        >>> descs = hynet(input) # 16x128
    """

    patch_size = 32

    def __init__(
        self,
        pretrained: bool = False,
        is_bias: bool = True,
        is_bias_FRN: bool = True,
        dim_desc: int = 128,
        drop_rate: float = 0.3,
        eps_l2_norm: float = 1e-10,
    ) -> None:
        super().__init__()
        self.eps_l2_norm = eps_l2_norm
        self.dim_desc = dim_desc
        self.drop_rate = drop_rate

        def conv_stage(in_ch: int, out_ch: int, stride: int = 1) -> list:
            # One 3x3 conv -> FRN -> TLU triplet; the order fixes the state_dict indices.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=is_bias),
                FilterResponseNorm2d(out_ch, is_bias=is_bias_FRN),
                TLU(out_ch),
            ]

        # The first stage additionally normalizes the raw input patch before its convolution.
        self.layer1 = nn.Sequential(
            FilterResponseNorm2d(1, is_bias=is_bias_FRN),
            TLU(1),
            *conv_stage(1, 32),
        )
        self.layer2 = nn.Sequential(*conv_stage(32, 32))
        self.layer3 = nn.Sequential(*conv_stage(32, 64, stride=2))
        self.layer4 = nn.Sequential(*conv_stage(64, 64))
        self.layer5 = nn.Sequential(*conv_stage(64, 128, stride=2))
        self.layer6 = nn.Sequential(*conv_stage(128, 128))
        # The 8x8 convolution collapses the 8x8 map left by the two stride-2 stages
        # (for a 32x32 input) into a single dim_desc-dimensional vector.
        self.layer7 = nn.Sequential(
            nn.Dropout(self.drop_rate),
            nn.Conv2d(128, self.dim_desc, kernel_size=8, bias=False),
            nn.BatchNorm2d(self.dim_desc, affine=False),
        )
        # With alpha == size, beta = 0.5 and k = 0 the response norm divides each value by the
        # root of the sum of squares over the channel window, i.e. it acts as an L2 normalization.
        self.desc_norm = nn.LocalResponseNorm(
            size=2 * self.dim_desc,
            alpha=2.0 * self.dim_desc,
            beta=0.5,
            k=0.0,
        )

        # use torch.hub to load pretrained model
        if pretrained:
            state = torch.hub.load_state_dict_from_url(urls["liberty"], map_location=torch.device("cpu"))
            self.load_state_dict(state, strict=True)
        # NOTE(review): indentation was lost in SOURCE; eval() is assumed to run regardless of
        # `pretrained` (outside the branch above) - confirm against the upstream file.
        self.eval()

    def forward(self, x: Tensor) -> Tensor:
        """Run the patches through all stages and return flattened, normalized descriptors."""
        feats = self.layer1(x)
        feats = self.layer2(feats)
        feats = self.layer3(feats)
        feats = self.layer4(feats)
        feats = self.layer5(feats)
        feats = self.layer6(feats)
        feats = self.layer7(feats)
        # Offset by eps before normalizing so that an all-zero feature vector does not divide by zero.
        normed = self.desc_norm(feats + self.eps_l2_norm)
        return normed.view(normed.size(0), -1)
|