sosnet.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Dict
  18. import torch
  19. from torch import nn
  20. from kornia.core.check import KORNIA_CHECK_SHAPE
  21. urls: Dict[str, str] = {}
  22. urls["lib"] = "https://github.com/yuruntian/SOSNet/raw/master/sosnet-weights/sosnet_32x32_liberty.pth"
  23. urls["hp_a"] = "https://github.com/yuruntian/SOSNet/raw/master/sosnet-weights/sosnet_32x32_hpatches_a.pth"
  24. class SOSNet(nn.Module):
  25. r"""128-dimensional SOSNet model definition for 32x32 patches.
  26. This is based on the original code from paper
  27. "SOSNet:Second Order Similarity Regularization for Local Descriptor Learning".
  28. Args:
  29. pretrained: Download and set pretrained weights to the model.
  30. Shape:
  31. - Input: :math:`(B, 1, 32, 32)`
  32. - Output: :math:`(B, 128)`
  33. Examples:
  34. >>> input = torch.rand(8, 1, 32, 32)
  35. >>> sosnet = SOSNet()
  36. >>> descs = sosnet(input) # 8x128
  37. """
  38. patch_size = 32
  39. def __init__(self, pretrained: bool = False) -> None:
  40. super().__init__()
  41. self.layers = nn.Sequential(
  42. nn.InstanceNorm2d(1, affine=False),
  43. nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=False),
  44. nn.BatchNorm2d(32, affine=False),
  45. nn.ReLU(),
  46. nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
  47. nn.BatchNorm2d(32, affine=False),
  48. nn.ReLU(),
  49. nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
  50. nn.BatchNorm2d(64, affine=False),
  51. nn.ReLU(),
  52. nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
  53. nn.BatchNorm2d(64, affine=False),
  54. nn.ReLU(),
  55. nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
  56. nn.BatchNorm2d(128, affine=False),
  57. nn.ReLU(),
  58. nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
  59. nn.BatchNorm2d(128, affine=False),
  60. nn.ReLU(),
  61. nn.Dropout(0.1),
  62. nn.Conv2d(128, 128, kernel_size=8, bias=False),
  63. nn.BatchNorm2d(128, affine=False),
  64. )
  65. self.desc_norm = nn.Sequential(nn.LocalResponseNorm(256, alpha=256.0, beta=0.5, k=0.0))
  66. # load pretrained model
  67. if pretrained:
  68. pretrained_dict = torch.hub.load_state_dict_from_url(urls["lib"], map_location=torch.device("cpu"))
  69. self.load_state_dict(pretrained_dict, strict=True)
  70. self.eval()
  71. def forward(self, input: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
  72. KORNIA_CHECK_SHAPE(input, ["B", "1", "32", "32"])
  73. descr = self.desc_norm(self.layers(input) + eps)
  74. descr = descr.view(descr.size(0), -1)
  75. return descr