# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import warnings
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor, concatenate
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.feature.sold2.structures import DetectorCfg, LineMatcherCfg
from kornia.geometry.conversions import normalize_pixel_coordinates
from kornia.utils import dataclass_to_dict, dict_to_dataclass

from .backbones import SOLD2Net
from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions

urls: Dict[str, str] = {}
urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"


class SOLD2(Module):
    r"""Module, which detects and describe line segments in an image.

    This is based on the original code from the paper "SOLD²: Self-supervised
    Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details.

    Args:
        config: Dict specifying parameters. None will load the default parameters,
            which are tuned for images in the range 400~800 px.
        pretrained: If True, download and set pretrained weights to the model.

    Returns:
        The raw junction and line heatmaps, the semi-dense descriptor map,
        as well as the list of detected line segments (ij coordinates convention).

    Example:
        >>> images = torch.rand(2, 1, 64, 64)
        >>> sold2 = SOLD2()
        >>> outputs = sold2(images)
        >>> line_seg1 = outputs["line_segments"][0]
        >>> line_seg2 = outputs["line_segments"][1]
        >>> desc1 = outputs["dense_desc"][0]
        >>> desc2 = outputs["dense_desc"][1]
        >>> matches = sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])

    """

    def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None:
        # Plain dictionaries are still accepted for backward compatibility, but are
        # converted to the dataclass representation used by the rest of the module.
        if isinstance(config, dict):
            warnings.warn(
                "Usage of config as a plain dictionary is deprecated in favor of"
                " `kornia.feature.sold2.structures.DetectorCfg`. The support of plain dictionaries"
                " as config will be removed in kornia v0.8.0 (December 2024).",
                category=DeprecationWarning,
                stacklevel=2,
            )
            config = dict_to_dataclass(config, DetectorCfg)
        super().__init__()

        # Initialize some parameters
        self.config = config if config is not None else DetectorCfg()
        # Only difference to SOLD2_detector DetectorCfg.
        # NOTE: this sets the flag on the config object passed by the caller (in place).
        self.config.use_descriptor = True
        self.grid_size = self.config.grid_size
        self.junc_detect_thresh = self.config.detection_thresh
        self.max_num_junctions = self.config.max_num_junctions

        # Load the pre-trained model
        self.model = SOLD2Net(dataclass_to_dict(self.config))
        if pretrained:
            pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=torch.device("cpu"))
            state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"])
            self.model.load_state_dict(state_dict)
        self.eval()

        # Initialize the line detector
        self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg)

        # Initialize the line matcher
        self.line_matcher = WunschLineMatcher(self.config.line_matcher_cfg)

    def forward(self, img: Tensor) -> Dict[str, Any]:
        """Run forward.

        Args:
            img: batched images with shape :math:`(B, 1, H, W)`.

        Returns:
            line_segments: list of N line segments in each of the B images :math:`List[(N, 2, 2)]`.
            junction_heatmap: raw junction heatmap of shape :math:`(B, H, W)`.
            line_heatmap: raw line heatmap of shape :math:`(B, H, W)`.
            dense_desc: the semi-dense descriptor map of shape :math:`(B, 128, H/4, W/4)`.

        """
        KORNIA_CHECK_SHAPE(img, ["B", "1", "H", "W"])
        outputs = {}

        # Forward pass of the CNN backbone
        net_outputs = self.model(img)
        outputs["junction_heatmap"] = net_outputs["junctions"]
        outputs["line_heatmap"] = net_outputs["heatmap"]
        outputs["dense_desc"] = net_outputs["descriptors"]

        # Loop through all images: junction and line detection are done per image
        lines = []
        for junc_prob, heatmap in zip(net_outputs["junctions"], net_outputs["heatmap"]):
            # Get the junctions
            junctions = prob_to_junctions(junc_prob, self.grid_size, self.junc_detect_thresh, self.max_num_junctions)

            # Run the line detector
            line_map, junctions, _ = self.line_detector.detect(junctions, heatmap)
            lines.append(line_map_to_segments(junctions, line_map))
        outputs["line_segments"] = lines

        return outputs

    def match(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
        """Find the best matches between two sets of line segments and their corresponding descriptors.

        Args:
            line_seg1: list of line segments in image 1, with shape [num_lines1, 2, 2].
            line_seg2: list of line segments in image 2, with shape [num_lines2, 2, 2].
            desc1: semi-dense descriptor map of image 1, with shape [1, 128, H/4, W/4].
            desc2: semi-dense descriptor map of image 2, with shape [1, 128, H/4, W/4].

        Returns:
            A tensor of size [num_lines1] indicating the index in line_seg2 of the matched line,
            for each line in line_seg1. -1 means that the line is not matched.

        """
        return self.line_matcher(line_seg1, line_seg2, desc1, desc2)

    def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt a pretrained checkpoint so that it can be loaded into ``self.model``.

        The ``w_junc``, ``w_heatmap`` and ``w_desc`` entries are removed and the
        ``heatmap_decoder.conv_block_lst.2.{weight,bias}`` entries are renamed to
        ``heatmap_decoder.conv_block_lst.2.0.{weight,bias}``.

        Args:
            state_dict: the checkpoint state dict. It is modified in place.

        Returns:
            The same (modified) dictionary.

        Raises:
            KeyError: if one of the expected keys is missing from ``state_dict``.

        """
        del state_dict["w_junc"]
        del state_dict["w_heatmap"]
        del state_dict["w_desc"]
        state_dict["heatmap_decoder.conv_block_lst.2.0.weight"] = state_dict["heatmap_decoder.conv_block_lst.2.weight"]
        state_dict["heatmap_decoder.conv_block_lst.2.0.bias"] = state_dict["heatmap_decoder.conv_block_lst.2.bias"]
        del state_dict["heatmap_decoder.conv_block_lst.2.weight"]
        del state_dict["heatmap_decoder.conv_block_lst.2.bias"]
        return state_dict


class WunschLineMatcher(Module):
    """Class matching two sets of line segments with the Needleman-Wunsch algorithm.

    TODO: move it later in kornia.feature.matching
    """

    def __init__(self, config: Optional[LineMatcherCfg] = None) -> None:
        super().__init__()
        # Initialize the parameters
        if config is None:
            config = LineMatcherCfg()
        self.config = config
        self.cross_check = self.config.cross_check
        self.num_samples = self.config.num_samples
        self.min_dist_pts = self.config.min_dist_pts
        self.top_k_candidates = self.config.top_k_candidates
        self.grid_size = self.config.grid_size
        self.line_score = self.config.line_score

    def forward(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
        """Find the best matches between two sets of line segments and their corresponding descriptors.

        Args:
            line_seg1: line segments of image 1, with shape [num_lines1, 2, 2].
            line_seg2: line segments of image 2, with shape [num_lines2, 2, 2].
            desc1: descriptor map of image 1, with shape [B, D, H, W]. Only the first
                batch element is used.
            desc2: descriptor map of image 2, with shape [B, D, H, W]. Only the first
                batch element is used.

        Returns:
            A tensor of size [num_lines1] with, for each line of ``line_seg1``, the index of
            its match in ``line_seg2``, or -1 when the line is not matched.

        """
        # The two sets of lines are independent in size, and descriptor maps
        # are not required to be square.
        KORNIA_CHECK_SHAPE(line_seg1, ["N", "2", "2"])
        KORNIA_CHECK_SHAPE(line_seg2, ["M", "2", "2"])
        KORNIA_CHECK_SHAPE(desc1, ["B", "D", "H", "W"])
        KORNIA_CHECK_SHAPE(desc2, ["B", "D", "H", "W"])
        device = desc1.device
        # Image size recovered from the descriptor map resolution and the grid size
        img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
        img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)

        # Default case when an image has no lines
        if len(line_seg1) == 0:
            return torch.empty(0, dtype=torch.int, device=device)
        if len(line_seg2) == 0:
            return -torch.ones(len(line_seg1), dtype=torch.int, device=device)

        # Sample points regularly along each line
        line_points1, valid_points1 = self.sample_line_points(line_seg1)
        line_points2, valid_points2 = self.sample_line_points(line_seg2)
        line_points1 = line_points1.reshape(-1, 2)
        line_points2 = line_points2.reshape(-1, 2)

        # Extract the descriptors for each point
        grid1 = keypoints_to_grid(line_points1, img_size1)
        grid2 = keypoints_to_grid(line_points2, img_size2)
        desc1 = F.normalize(F.grid_sample(desc1, grid1, align_corners=False)[0, :, :, 0], dim=0)
        desc2 = F.normalize(F.grid_sample(desc2, grid2, align_corners=False)[0, :, :, 0], dim=0)

        # Precompute the distance between line points for every pair of lines
        # Assign a score of -1 for invalid points
        scores = desc1.t() @ desc2
        scores[~valid_points1.flatten()] = -1
        scores[:, ~valid_points2.flatten()] = -1
        scores = scores.reshape(len(line_seg1), self.num_samples, len(line_seg2), self.num_samples)
        scores = scores.permute(0, 2, 1, 3)
        # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)

        # Pre-filter the line candidates and find the best match for each line
        matches = self.filter_and_match_lines(scores)

        # [Optionally] filter matches with mutual nearest neighbor filtering
        if self.cross_check:
            matches2 = self.filter_and_match_lines(scores.permute(1, 0, 3, 2))
            mutual = matches2[matches] == torch.arange(len(line_seg1), device=device)
            matches[~mutual] = -1

        return matches

    def sample_line_points(self, line_seg: Tensor) -> Tuple[Tensor, Tensor]:
        """Regularly sample points along each line segments, with a minimal distance between each point.

        Pad the remaining points.

        Args:
            line_seg: an Nx2x2 Tensor.

        Returns:
            line_points: an N x num_samples x 2 Tensor.
            valid_points: a boolean N x num_samples Tensor.

        """
        # Unpacking fails fast when the input is not a 3-dimensional tensor
        _N, _, _ = line_seg.shape
        M = self.num_samples
        dev = line_seg.device

        # Number of points per line: one every `min_dist_pts`, between 2 and `num_samples`
        lengths = torch.norm(line_seg[:, 0] - line_seg[:, 1], dim=1)
        num_pts = torch.clamp((lengths / self.min_dist_pts).floor().int(), min=2, max=M)  # (N,)

        orig = line_seg[:, 0].unsqueeze(1)
        dirs = (line_seg[:, 1] - line_seg[:, 0]).unsqueeze(1)

        # Interpolation factor of each sample; the first `num_pts` samples of a line
        # go from its first endpoint (0) to its second endpoint (1)
        idx = torch.arange(M, device=dev).unsqueeze(0)
        denom = (num_pts - 1).unsqueeze(1)
        alpha = idx / denom

        pts = orig + dirs * alpha.unsqueeze(-1)

        # Samples past `num_pts` are padding: flag them invalid and zero them out
        valid = idx < num_pts.unsqueeze(1)
        pts = pts.masked_fill(~valid.unsqueeze(-1), 0.0)

        return pts, valid

    def filter_and_match_lines(self, scores: Tensor) -> Tensor:
        """Use scores to keep the top k best lines.

        Compute the Needleman-Wunsch algorithm on each candidate pairs, and keep the highest score.

        Args:
            scores: a (N, M, n, n) Tensor containing the pairwise scores
                    of the elements to match.

        Returns:
            matches: a (N) Tensor containing the indices of the best match

        """
        KORNIA_CHECK_SHAPE(scores, ["N", "M", "n", "n"])

        # Pre-filter the pairs and keep the top k best candidate lines.
        # Scores equal to -1 mark invalid points and are excluded from the averages.
        line_scores1 = scores.max(3)[0]
        valid_scores1 = line_scores1 != -1
        line_scores1 = (line_scores1 * valid_scores1).sum(2) / valid_scores1.sum(2)
        line_scores2 = scores.max(2)[0]
        valid_scores2 = line_scores2 != -1
        line_scores2 = (line_scores2 * valid_scores2).sum(2) / valid_scores2.sum(2)
        line_scores = (line_scores1 + line_scores2) / 2
        topk_lines = torch.argsort(line_scores, dim=1)[:, -self.top_k_candidates :]
        # topk_lines.shape = (n_lines1, top_k_candidates)
        top_scores = torch.take_along_dim(scores, topk_lines[:, :, None, None], dim=1)

        # Consider the reversed line segments as well
        top_scores = concatenate([top_scores, torch.flip(top_scores, dims=[-1])], 1)

        # Compute the line distance matrix with Needleman-Wunsch algo and
        # retrieve the closest line neighbor
        n_lines1, top2k, n, m = top_scores.shape
        top_scores = top_scores.reshape((n_lines1 * top2k, n, m))
        nw_scores = self.needleman_wunsch(top_scores)
        nw_scores = nw_scores.reshape(n_lines1, top2k)
        # The second half of the candidates are the reversed versions of the first half,
        # so the remainder maps both back to the same candidate index
        matches = torch.remainder(torch.argmax(nw_scores, dim=1), top2k // 2)
        # Keep the row index on the same device as the tensor being indexed
        matches = topk_lines[torch.arange(n_lines1, device=topk_lines.device), matches]
        return matches

    def needleman_wunsch(self, scores: Tensor) -> Tensor:
        """Batched implementation of the Needleman-Wunsch algorithm.

        The cost of the InDel operation is set to 0 by subtracting the gap
        penalty to the scores.

        Args:
            scores: a (B, N, M) Tensor containing the pairwise scores
                    of the elements to match.

        Returns:
            a (B) Tensor containing the final alignment score of each batch element.

        """
        KORNIA_CHECK_SHAPE(scores, ["B", "N", "M"])

        # Recalibrate the scores to get a gap score of 0
        gap = 0.1
        B, N, M = scores.shape
        S = scores - gap

        # Allocate the table with the dtype of the recalibrated scores, so that the
        # indexed assignment below never mixes dtypes (e.g. float64 or half inputs)
        dp = torch.zeros(B, N + 1, M + 1, dtype=S.dtype, device=scores.device)

        # Fill the table one anti-diagonal (i + j == k) at a time: every cell of a
        # diagonal only depends on the two previous diagonals, so a whole diagonal
        # can be updated at once
        for k in range(2, N + M + 1):
            i_min = max(1, k - M)
            i_max = min(N, k - 1)

            i = torch.arange(i_min, i_max + 1, device=scores.device)
            j = k - i

            up = dp[:, i - 1, j]
            left = dp[:, i, j - 1]
            diag = dp[:, i - 1, j - 1] + S[:, i - 1, j - 1]

            dp[:, i, j] = torch.max(torch.max(up, left), diag)

        return dp[:, -1, -1]


def keypoints_to_grid(keypoints: Tensor, img_size: Tuple[int, int]) -> Tensor:
    """Convert a list of keypoints into a grid in [-1, 1]² that can be used in torch.nn.functional.grid_sample.

    Args:
        keypoints: a tensor [N, 2] of N keypoints (ij coordinates convention).
        img_size: the original image size (H, W)

    Returns:
        a tensor [1, N, 1, 2] of normalized keypoints (xy coordinates convention).

    """
    KORNIA_CHECK_SHAPE(keypoints, ["N", "2"])
    n_points = len(keypoints)
    # Swap the two columns (ij -> xy) before normalizing with (height, width)
    grid_points = normalize_pixel_coordinates(keypoints[:, [1, 0]], img_size[0], img_size[1])
    grid_points = grid_points.view(-1, n_points, 1, 2)
    return grid_points


def batched_linspace(start: Tensor, end: Tensor, step: int, dim: int) -> Tensor:
    """Batch version of torch.linspace (similar to the numpy one).

    Args:
        start: tensor with the first value of each sequence.
        end: tensor with the last value of each sequence.
        step: number of samples generated for each sequence.
        dim: index of the new dimension holding the samples.

    Returns:
        a tensor with one extra dimension of size ``step`` inserted at ``dim``.

    """
    intervals = ((end - start) / (step - 1)).unsqueeze(dim)
    broadcast_size = [1] * len(intervals.shape)
    broadcast_size[dim] = step
    samples = torch.arange(step, dtype=torch.float, device=start.device).reshape(broadcast_size)
    samples = start.unsqueeze(dim) + samples * intervals
    return samples