| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Dict, Optional
- import torch
- import torch.nn.functional as F
- from torch import nn
- from kornia.constants import pi
- from kornia.core.check import KORNIA_CHECK_LAF, KORNIA_CHECK_SHAPE
- from kornia.filters import SpatialGradient, get_gaussian_discrete_kernel1d, get_gaussian_kernel2d
- from kornia.geometry import rad2deg
- from .laf import extract_patches_from_pyramid, get_laf_orientation, set_laf_orientation
# Download locations of pretrained checkpoints, keyed by model name.
urls: Dict[str, str] = {
    "orinet": "https://github.com/ducha-aiki/affnet/raw/master/pretrained/OriNet.pth",
}
class PassLAF(nn.Module):
    """Identity stand-in for a local feature orientation or affine shape estimator.

    It accepts the same ``(laf, img)`` pair as :class:`LAFOrienter` and hands the
    LAFs back without touching them.
    """

    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Return the input LAFs as they are; the image is ignored.

        Args:
            laf: :math:`(B, N, 2, 3)`
            img: :math:`(B, 1, H, W)`

        Returns:
            LAF, unchanged :math:`(B, N, 2, 3)`
        """
        return laf
class PatchDominantGradientOrientation(nn.Module):
    """Module, which estimates the dominant gradient orientation of the given patches, in radians.

    Zero angle points towards right.

    The estimate is the peak of a magnitude-weighted, smoothed histogram of per-pixel
    gradient orientations, refined to sub-bin precision with a parabola fit.

    Args:
        patch_size: size of the (square) input patch.
        num_angular_bins: number of histogram bins.
        eps: for safe division, and arctan.
    """

    def __init__(self, patch_size: int = 32, num_angular_bins: int = 36, eps: float = 1e-8) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.num_ang_bins = num_angular_bins
        self.gradient = SpatialGradient("sobel", 1)
        self.eps = eps
        # 5-tap smoothing of the orientation histogram, initialised with a discrete Gaussian.
        # Padding is circular because the histogram wraps around (bins are taken modulo
        # ``num_ang_bins`` in ``forward``).
        self.angular_smooth = nn.Conv1d(1, 1, kernel_size=5, padding=2, bias=False, padding_mode="circular")
        with torch.no_grad():
            self.angular_smooth.weight[:] = get_gaussian_discrete_kernel1d(5, 1.6)
        # Per-pixel weights applied to the gradient magnitude; sigma is tied to the patch size.
        sigma: float = float(self.patch_size) / 6.0
        self.weighting = get_gaussian_kernel2d((self.patch_size, self.patch_size), (sigma, sigma), True)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(patch_size={self.patch_size}, num_ang_bins={self.num_ang_bins}, eps={self.eps})"
        )

    def forward(self, patch: torch.Tensor) -> torch.Tensor:
        """Run forward.

        Args:
            patch: :math:`(B, 1, H, W)`, with ``H == W == patch_size``.

        Returns:
            angle in radians: :math:`(B)`

        Raises:
            TypeError: if the patch is not single-channel or its spatial size differs
                from ``patch_size``.
        """
        KORNIA_CHECK_SHAPE(patch, ["B", "1", "H", "W"])
        _, CH, H, W = patch.size()
        if (W != self.patch_size) or (H != self.patch_size) or (CH != 1):
            raise TypeError(
                f"input shape should be [Bx1x{self.patch_size}x{self.patch_size}]. Got {patch.size()}"
            )
        # Follow the dtype/device of the input so the module works without an explicit ``.to()``.
        self.weighting = self.weighting.to(patch.dtype).to(patch.device)
        self.angular_smooth = self.angular_smooth.to(patch.dtype).to(patch.device)
        grads: torch.Tensor = self.gradient(patch)
        # unpack the edges
        gx: torch.Tensor = grads[:, :, 0]
        gy: torch.Tensor = grads[:, :, 1]

        mag: torch.Tensor = torch.sqrt(gx * gx + gy * gy + self.eps) * self.weighting
        ori: torch.Tensor = torch.atan2(gy, gx + self.eps) + 2.0 * pi

        # Continuous bin coordinate of every pixel; each pixel votes into the two nearest
        # bins with linear interpolation weights scaled by its gradient magnitude.
        o_big = float(self.num_ang_bins) * (ori + 1.0 * pi) / (2.0 * pi)
        bo0_big = torch.floor(o_big)
        wo1_big = o_big - bo0_big
        bo0_big = bo0_big % self.num_ang_bins
        bo1_big = (bo0_big + 1) % self.num_ang_bins
        wo0_big = (1.0 - wo1_big) * mag
        wo1_big = wo1_big * mag

        ang_bins_list = [
            F.adaptive_avg_pool2d(
                (bo0_big == i).to(patch.dtype) * wo0_big + (bo1_big == i).to(patch.dtype) * wo1_big, (1, 1)
            )
            for i in range(self.num_ang_bins)
        ]
        ang_bins = torch.cat(ang_bins_list, 1).view(-1, 1, self.num_ang_bins)
        ang_bins = self.angular_smooth(ang_bins).view(-1, self.num_ang_bins)

        # Peak bin and its two circular neighbours.
        values, indices = ang_bins.max(1)
        indices_left = (self.num_ang_bins + indices - 1) % self.num_ang_bins
        indices_right = (indices + 1) % self.num_ang_bins
        left = torch.gather(ang_bins, 1, indices_left.reshape(-1, 1)).reshape(-1)
        center = values
        right = torch.gather(ang_bins, 1, indices_right.reshape(-1, 1)).reshape(-1)

        # Parabola fit through (left, center, right). A flat peak makes the denominator exactly
        # zero; substitute 1 there so the offset stays finite (the numerator is ~0 in that case)
        # instead of producing NaN/inf angles.
        denom = left + right - 2.0 * center
        safe_denom = torch.where(denom == 0, torch.ones_like(denom), denom)
        c_subpix = 0.5 * (left - right) / safe_denom
        angle = -((2.0 * pi * (indices.to(patch.dtype) + c_subpix) / float(self.num_ang_bins)) - pi)
        return angle
class OriNet(nn.Module):
    """Network, which estimates the canonical orientation of the given 32x32 patches, in radians.

    Zero angle points towards right. This is based on the original code from paper
    "Repeatability Is Not Enough: Learning Discriminative Affine Regions via Discriminability".
    See :cite:`AffNet2018` for more details.

    Args:
        pretrained: Download and set pretrained weights to the model.
        eps: added to both arguments of atan2 to keep them away from the (0, 0) singularity.

    Returns:
        Angle in radians.

    Shape:
        - Input: (B, 1, 32, 32)
        - Output: (B)

    Examples:
        >>> input = torch.rand(16, 1, 32, 32)
        >>> orinet = OriNet()
        >>> angle = orinet(input) # 16
    """

    def __init__(self, pretrained: bool = False, eps: float = 1e-8) -> None:
        super().__init__()
        # NOTE: the layer order fixes the state-dict keys (``features.<idx>``) of the
        # pretrained checkpoint, so layers must not be reordered or regrouped.
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16, affine=False),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16, affine=False),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32, affine=False),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, affine=False),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Conv2d(64, 2, kernel_size=8, stride=1, padding=1, bias=True),
            nn.Tanh(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.eps = eps
        # use torch.hub to load pretrained model
        if pretrained:
            pretrained_dict = torch.hub.load_state_dict_from_url(urls["orinet"], map_location=torch.device("cpu"))
            self.load_state_dict(pretrained_dict["state_dict"], strict=False)
        # The module starts in inference mode (dropout off, batch-norm running statistics).
        self.eval()

    @staticmethod
    def _normalize_input(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Standardize every patch to zero mean and unit std over its (C, H, W) dims.

        Args:
            x: input patches.
            eps: added to the std to avoid division by zero on constant patches.

        Returns:
            The normalized patches, same shape as ``x``.
        """
        sp, mp = torch.std_mean(x, dim=(-3, -2, -1), keepdim=True)
        # WARNING: we need to .detach() input, otherwise the gradients produced by
        # the patches extractor with F.grid_sample are very noisy, making the detector
        # training totally unstable.
        return (x - mp.detach()) / (sp.detach() + eps)

    def forward(self, patch: torch.Tensor) -> torch.Tensor:
        """Run forward.

        Args:
            patch: :math:`(B, 1, H, W)`

        Returns:
            angle in radians: :math:`(B)`
        """
        # The network regresses two values per patch, which are turned into an angle by atan2.
        xy = self.features(self._normalize_input(patch)).view(-1, 2)
        # Use the configured eps for both arguments (the first one used to be hard-coded to 1e-8).
        angle = torch.atan2(xy[:, 0] + self.eps, xy[:, 1] + self.eps)
        return angle
class LAFOrienter(nn.Module):
    """Module, which extracts patches using input images and local affine frames (LAFs).

    Then runs :class:`~kornia.feature.PatchDominantGradientOrientation` or
    :class:`~kornia.feature.OriNet` on patches and then rotates the LAFs by the estimated angles

    Args:
        patch_size: side length of the square patches extracted for every LAF and fed to the
            angle detector.
        num_angular_bins: number of histogram bins of the default
            :class:`~kornia.feature.PatchDominantGradientOrientation`; not used when
            ``angle_detector`` is given.
        angle_detector: Patch orientation estimator, e.g. :class:`~kornia.feature.PatchDominantGradientOrientation`
          or OriNet.
    """  # pylint: disable

    def __init__(
        self, patch_size: int = 32, num_angular_bins: int = 36, angle_detector: Optional[nn.Module] = None
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.num_ang_bins = num_angular_bins
        self.angle_detector: nn.Module
        if angle_detector is None:
            # Fall back to the hand-crafted gradient-histogram estimator.
            self.angle_detector = PatchDominantGradientOrientation(self.patch_size, self.num_ang_bins)
        else:
            self.angle_detector = angle_detector

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(patch_size={self.patch_size}, angle_detector={self.angle_detector})"

    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Run forward.

        Args:
            laf: :math:`(B, N, 2, 3)`
            img: :math:`(B, 1, H, W)`

        Returns:
            LAF_out: :math:`(B, N, 2, 3)`

        Raises:
            ValueError: if ``laf`` and ``img`` have different batch sizes.
        """
        KORNIA_CHECK_LAF(laf)
        KORNIA_CHECK_SHAPE(img, ["B", "C", "H", "W"])
        if laf.size(0) != img.size(0):
            # Report the sizes in the same order as they are named in the message.
            raise ValueError(f"Batch size of laf and img should be the same. Got {laf.size(0)}, {img.size(0)}")
        B, N = laf.shape[:2]
        # One single-channel patch per LAF, flattened over batch and keypoint dims.
        patches: torch.Tensor = extract_patches_from_pyramid(img, laf, self.patch_size).view(
            -1, 1, self.patch_size, self.patch_size
        )
        angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
        # The estimated angle is relative to the patch, so it is added to the current LAF
        # orientation (presumably expressed in degrees, hence the rad2deg conversion).
        prev_angle = get_laf_orientation(laf).view_as(angles_radians)
        laf_out: torch.Tensor = set_laf_orientation(laf, rad2deg(angles_radians) + prev_angle)
        return laf_out
|