| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import warnings
- from typing import Any, Dict, Optional, Tuple
- import torch
- import torch.nn.functional as F
- from kornia.core import Module, Tensor, concatenate
- from kornia.core.check import KORNIA_CHECK_SHAPE
- from kornia.feature.sold2.structures import DetectorCfg, LineMatcherCfg
- from kornia.geometry.conversions import normalize_pixel_coordinates
- from kornia.utils import dataclass_to_dict, dict_to_dataclass
- from .backbones import SOLD2Net
- from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions
# Download locations of the pretrained SOLD2 weights, keyed by model name.
urls: Dict[str, str] = {
    "wireframe": "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth",
}
class SOLD2(Module):
    r"""Module, which detects and describes line segments in an image.

    This is based on the original code from the paper "SOLD²: Self-supervised
    Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details.

    Args:
        pretrained: If True, download and set pretrained weights to the model.
        config: Detector parameters. None will load the default parameters,
            which are tuned for images in the range 400~800 px. Passing a plain
            dict is deprecated (a ``DeprecationWarning`` is emitted) and it is
            converted to :class:`DetectorCfg`. A given config object is copied,
            never modified in place.

    Returns:
        The raw junction and line heatmaps, the semi-dense descriptor map,
        as well as the list of detected line segments (ij coordinates convention).

    Example:
        >>> images = torch.rand(2, 1, 64, 64)
        >>> sold2 = SOLD2()
        >>> outputs = sold2(images)
        >>> line_seg1 = outputs["line_segments"][0]
        >>> line_seg2 = outputs["line_segments"][1]
        >>> desc1 = outputs["dense_desc"][0]
        >>> desc2 = outputs["dense_desc"][1]
        >>> matches = sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])
    """

    def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None:
        import copy

        if isinstance(config, dict):
            warnings.warn(
                "Usage of config as a plain dictionary is deprecated in favor of"
                " `kornia.feature.sold2.structures.DetectorCfg`. The support of plain dictionaries"
                " as config will be removed in kornia v0.8.0 (December 2024).",
                category=DeprecationWarning,
                stacklevel=2,
            )
            config = dict_to_dataclass(config, DetectorCfg)
        super().__init__()
        # Initialize some parameters. A user-provided config is copied so that forcing
        # ``use_descriptor`` below does not leak into the caller's object.
        self.config = copy.copy(config) if config is not None else DetectorCfg()
        self.config.use_descriptor = True  # Only difference to SOLD2_detector DetectorCfg
        self.grid_size = self.config.grid_size
        self.junc_detect_thresh = self.config.detection_thresh
        self.max_num_junctions = self.config.max_num_junctions
        # Load the pre-trained model
        self.model = SOLD2Net(dataclass_to_dict(self.config))
        if pretrained:
            pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=torch.device("cpu"))
            state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"])
            self.model.load_state_dict(state_dict)
        self.eval()
        # Initialize the line detector
        self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg)
        # Initialize the line matcher
        self.line_matcher = WunschLineMatcher(self.config.line_matcher_cfg)

    def forward(self, img: Tensor) -> Dict[str, Any]:
        """Run forward.

        Args:
            img: batched images with shape :math:`(B, 1, H, W)`.

        Returns:
            line_segments: list of N line segments in each of the B images :math:`List[(N, 2, 2)]`.
            junction_heatmap: raw junction heatmap of shape :math:`(B, H, W)`.
            line_heatmap: raw line heatmap of shape :math:`(B, H, W)`.
            dense_desc: the semi-dense descriptor map of shape :math:`(B, 128, H/4, W/4)`.
        """
        KORNIA_CHECK_SHAPE(img, ["B", "1", "H", "W"])
        outputs = {}
        # Forward pass of the CNN backbone
        net_outputs = self.model(img)
        outputs["junction_heatmap"] = net_outputs["junctions"]
        outputs["line_heatmap"] = net_outputs["heatmap"]
        outputs["dense_desc"] = net_outputs["descriptors"]

        # Detect the line segments independently for each image of the batch
        lines = []
        for junc_prob, heatmap in zip(net_outputs["junctions"], net_outputs["heatmap"]):
            # Get the junctions
            junctions = prob_to_junctions(junc_prob, self.grid_size, self.junc_detect_thresh, self.max_num_junctions)
            # Run the line detector; it may refine the junctions it returns
            line_map, junctions, _ = self.line_detector.detect(junctions, heatmap)
            lines.append(line_map_to_segments(junctions, line_map))
        outputs["line_segments"] = lines
        return outputs

    def match(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
        """Find the best matches between two sets of line segments and their corresponding descriptors.

        Args:
            line_seg1: line segments in image 1, with shape [num_lines1, 2, 2].
            line_seg2: line segments in image 2, with shape [num_lines2, 2, 2].
            desc1: semi-dense descriptor map of image 1, with shape [1, 128, H/4, W/4].
            desc2: semi-dense descriptor map of image 2, with shape [1, 128, H/4, W/4].

        Returns:
            An integer tensor of size [num_lines1] holding, for each line in line_seg1,
            the index in line_seg2 of the matched line. -1 means that the line is not matched.
        """
        return self.line_matcher(line_seg1, line_seg2, desc1, desc2)

    def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt the released checkpoint to the layout of :class:`SOLD2Net` (modifies ``state_dict`` in place).

        The three loss-weighting entries (``w_junc``, ``w_heatmap``, ``w_desc``) are dropped,
        and the last heatmap decoder convolution is renamed from ``conv_block_lst.2.*``
        to ``conv_block_lst.2.0.*``.
        """
        del state_dict["w_junc"]
        del state_dict["w_heatmap"]
        del state_dict["w_desc"]
        for param in ("weight", "bias"):
            state_dict[f"heatmap_decoder.conv_block_lst.2.0.{param}"] = state_dict.pop(
                f"heatmap_decoder.conv_block_lst.2.{param}"
            )
        return state_dict
class WunschLineMatcher(Module):
    """Class matching two sets of line segments with the Needleman-Wunsch algorithm.

    The matcher works in four steps:

    1. Each line is sampled at regularly spaced points.
    2. A descriptor is sampled from the semi-dense descriptor map at every point.
    3. The best candidate lines are pre-selected from the point-to-point descriptor similarities.
    4. Each line of the first set is matched to the candidate whose point sequence aligns
       best under Needleman-Wunsch.

    Args:
        config: matcher parameters. None uses the default :class:`LineMatcherCfg`.

    TODO: move it later in kornia.feature.matching
    """

    def __init__(self, config: Optional[LineMatcherCfg] = None) -> None:
        super().__init__()
        # Initialize the parameters
        self.config = config if config is not None else LineMatcherCfg()
        self.cross_check = self.config.cross_check
        self.num_samples = self.config.num_samples
        self.min_dist_pts = self.config.min_dist_pts
        self.top_k_candidates = self.config.top_k_candidates
        self.grid_size = self.config.grid_size
        self.line_score = self.config.line_score

    def forward(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
        """Find the best matches between two sets of line segments and their corresponding descriptors.

        Args:
            line_seg1: line segments of image 1, shape (N1, 2, 2), ij coordinates convention.
            line_seg2: line segments of image 2, shape (N2, 2, 2), ij coordinates convention.
            desc1: descriptor map of image 1, shape (B, D, H, W). B must be 1 because the
                sampling grid built by ``keypoints_to_grid`` has a single batch element.
            desc2: descriptor map of image 2, shape (B, D, H, W), with the same constraint.

        Returns:
            An int64 tensor of shape (N1,) with, for each line in ``line_seg1``, the index of
            its match in ``line_seg2``, or -1 when it is not matched.
        """
        KORNIA_CHECK_SHAPE(line_seg1, ["N", "2", "2"])
        KORNIA_CHECK_SHAPE(line_seg2, ["N", "2", "2"])
        KORNIA_CHECK_SHAPE(desc1, ["B", "D", "H", "W"])
        KORNIA_CHECK_SHAPE(desc2, ["B", "D", "H", "W"])
        device = desc1.device
        # Image sizes, assuming the descriptor maps are downsampled by ``grid_size``.
        img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
        img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)

        # Default cases when an image has no lines. They use int64, the same dtype as the
        # general path below, whose matches come from ``torch.argsort``.
        if len(line_seg1) == 0:
            return torch.empty(0, dtype=torch.long, device=device)
        if len(line_seg2) == 0:
            return torch.full((len(line_seg1),), -1, dtype=torch.long, device=device)

        # Sample points regularly along each line: (N, num_samples, 2) -> (N * num_samples, 2)
        line_points1, valid_points1 = self.sample_line_points(line_seg1)
        line_points2, valid_points2 = self.sample_line_points(line_seg2)
        line_points1 = line_points1.reshape(-1, 2)
        line_points2 = line_points2.reshape(-1, 2)

        # Sample (bilinearly, grid_sample's default) and L2-normalize the descriptor of
        # each point. Resulting shape is (D, N * num_samples).
        grid1 = keypoints_to_grid(line_points1, img_size1)
        grid2 = keypoints_to_grid(line_points2, img_size2)
        desc1 = F.normalize(F.grid_sample(desc1, grid1, align_corners=False)[0, :, :, 0], dim=0)
        desc2 = F.normalize(F.grid_sample(desc2, grid2, align_corners=False)[0, :, :, 0], dim=0)

        # Precompute the similarity between line points for every pair of lines.
        # Padding (invalid) points get a score of -1.
        scores = desc1.t() @ desc2
        scores[~valid_points1.flatten()] = -1
        scores[:, ~valid_points2.flatten()] = -1
        scores = scores.reshape(len(line_seg1), self.num_samples, len(line_seg2), self.num_samples)
        scores = scores.permute(0, 2, 1, 3)
        # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)

        # Pre-filter the line candidates and find the best match for each line
        matches = self.filter_and_match_lines(scores)

        # [Optionally] keep only mutual nearest neighbours
        if self.cross_check:
            matches2 = self.filter_and_match_lines(scores.permute(1, 0, 3, 2))
            mutual = matches2[matches] == torch.arange(len(line_seg1), device=device)
            matches[~mutual] = -1
        return matches

    def sample_line_points(self, line_seg: Tensor) -> Tuple[Tensor, Tensor]:
        """Regularly sample points along each line segment, with a minimal distance between each point.

        Each line gets ``floor(length / min_dist_pts)`` points, clamped to ``[2, num_samples]``.
        The points include both endpoints. The remaining slots up to ``num_samples`` are
        padded with zeros.

        Args:
            line_seg: an Nx2x2 Tensor.

        Returns:
            line_points: an N x num_samples x 2 Tensor.
            valid_points: a boolean N x num_samples Tensor, False on the padded slots.
        """
        KORNIA_CHECK_SHAPE(line_seg, ["N", "2", "2"])
        max_pts = self.num_samples
        lengths = torch.norm(line_seg[:, 0] - line_seg[:, 1], dim=1)
        num_pts = torch.clamp((lengths / self.min_dist_pts).floor().int(), min=2, max=max_pts)  # (N,)

        slot = torch.arange(max_pts, device=line_seg.device).unsqueeze(0)  # (1, M)
        # Fraction of the way from the first to the second endpoint; the last valid slot has alpha == 1.
        alpha = slot / (num_pts - 1).unsqueeze(1)  # (N, M)
        start = line_seg[:, 0].unsqueeze(1)  # (N, 1, 2)
        direction = (line_seg[:, 1] - line_seg[:, 0]).unsqueeze(1)  # (N, 1, 2)
        points = start + direction * alpha.unsqueeze(-1)  # (N, M, 2)

        valid = slot < num_pts.unsqueeze(1)  # (N, M)
        points = points.masked_fill(~valid.unsqueeze(-1), 0.0)
        return points, valid

    def filter_and_match_lines(self, scores: Tensor) -> Tensor:
        """Use scores to keep the top k best lines.

        Compute the Needleman-Wunsch algorithm on each candidate pair, and keep the highest score.
        Each candidate is also tried with its point order reversed, so a line matches
        regardless of endpoint ordering.

        Args:
            scores: a (N, M, n, n) Tensor containing the pairwise scores
                of the elements to match.

        Returns:
            matches: a (N) Tensor containing the indices of the best match
        """
        KORNIA_CHECK_SHAPE(scores, ["M", "N", "n", "n"])
        # Average best point similarity in both directions, ignoring padded points (score -1)
        line_scores1 = scores.max(3)[0]
        valid_scores1 = line_scores1 != -1
        line_scores1 = (line_scores1 * valid_scores1).sum(2) / valid_scores1.sum(2)
        line_scores2 = scores.max(2)[0]
        valid_scores2 = line_scores2 != -1
        line_scores2 = (line_scores2 * valid_scores2).sum(2) / valid_scores2.sum(2)
        line_scores = (line_scores1 + line_scores2) / 2

        # Pre-filter the pairs and keep the top k best candidate lines
        topk_lines = torch.argsort(line_scores, dim=1)[:, -self.top_k_candidates :]
        # topk_lines.shape = (n_lines1, top_k_candidates)
        top_scores = torch.take_along_dim(scores, topk_lines[:, :, None, None], dim=1)

        # Consider the reversed line segments as well
        top_scores = concatenate([top_scores, torch.flip(top_scores, dims=[-1])], 1)

        # Compute the line distance matrix with Needleman-Wunsch algo and
        # retrieve the closest line neighbor
        n_lines1, top2k, n, m = top_scores.shape
        top_scores = top_scores.reshape((n_lines1 * top2k, n, m))
        nw_scores = self.needleman_wunsch(top_scores)
        nw_scores = nw_scores.reshape(n_lines1, top2k)
        # Forward and reversed versions of a candidate share the same index modulo k
        matches = torch.remainder(torch.argmax(nw_scores, dim=1), top2k // 2)
        matches = topk_lines[torch.arange(n_lines1), matches]
        return matches

    def needleman_wunsch(self, scores: Tensor, gap: float = 0.1) -> Tensor:
        """Batched implementation of the Needleman-Wunsch algorithm.

        The cost of the InDel operation is set to 0 by subtracting the gap
        penalty from the scores. The DP table is filled one anti-diagonal at a time,
        since every cell of an anti-diagonal only depends on the two previous ones.

        Args:
            scores: a (B, N, M) Tensor containing the pairwise scores
                of the elements to match.
            gap: gap penalty subtracted from every score (default 0.1).

        Returns:
            A (B,) Tensor with the final alignment score of each batch element.
        """
        KORNIA_CHECK_SHAPE(scores, ["B", "N", "M"])
        B, N, M = scores.shape
        # Recalibrate the scores to get a gap score of 0
        S = scores - gap
        dp = torch.zeros(B, N + 1, M + 1, dtype=scores.dtype, device=scores.device)
        for k in range(2, N + M + 1):
            # Cells (i, j) with i + j == k, 1 <= i <= N, 1 <= j <= M
            i_min = max(1, k - M)
            i_max = min(N, k - 1)
            i = torch.arange(i_min, i_max + 1, device=scores.device)
            j = k - i
            up = dp[:, i - 1, j]
            left = dp[:, i, j - 1]
            diag = dp[:, i - 1, j - 1] + S[:, i - 1, j - 1]
            dp[:, i, j] = torch.max(torch.max(up, left), diag)
        return dp[:, -1, -1]
def keypoints_to_grid(keypoints: Tensor, img_size: Tuple[int, int]) -> Tensor:
    """Convert a list of keypoints into a grid in [-1, 1]² for :func:`torch.nn.functional.grid_sample`.

    Args:
        keypoints: a tensor [N, 2] of N keypoints (ij coordinates convention).
        img_size: the original image size (H, W).

    Returns:
        A tensor of shape [1, N, 1, 2] holding the normalized (x, y) = (j, i) coordinates.
        N may be 0.
    """
    KORNIA_CHECK_SHAPE(keypoints, ["N", "2"])
    # grid_sample expects (x, y) = (column, row) order: swap the ij columns.
    # NOTE(review): the matcher samples with align_corners=False; confirm that the
    # normalization convention of normalize_pixel_coordinates matches it.
    grid_points = normalize_pixel_coordinates(keypoints[:, [1, 0]], img_size[0], img_size[1])
    # (N, 2) -> (1, N, 1, 2). Unlike ``view(-1, N, 1, 2)`` this also works for N == 0.
    return grid_points[None, :, None, :]
def batched_linspace(start: Tensor, end: Tensor, step: int, dim: int) -> Tensor:
    """Batch version of :func:`torch.linspace` (similar to the numpy one).

    Args:
        start: tensor of start values.
        end: tensor of end values, same shape as ``start``.
        step: number of evenly spaced samples per (start, end) pair, endpoints included.
        dim: position of the new sample dimension in the output.

    Returns:
        A tensor with the shape of ``start`` plus a dimension of size ``step`` inserted at ``dim``.
        As with :func:`torch.linspace`, ``step == 1`` yields ``start`` only.
    """
    # Guard the denominator: with a single sample (step == 1) the interval is irrelevant
    # (it is multiplied by 0), and dividing by 0 would turn the result into NaN.
    intervals = ((end - start) / max(step - 1, 1)).unsqueeze(dim)
    broadcast_size = [1] * len(intervals.shape)
    broadcast_size[dim] = step
    samples = torch.arange(step, dtype=torch.float, device=start.device).reshape(broadcast_size)
    return start.unsqueeze(dim) + samples * intervals
|