sold2.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import warnings
  18. from typing import Any, Dict, Optional, Tuple
  19. import torch
  20. import torch.nn.functional as F
  21. from kornia.core import Module, Tensor, concatenate
  22. from kornia.core.check import KORNIA_CHECK_SHAPE
  23. from kornia.feature.sold2.structures import DetectorCfg, LineMatcherCfg
  24. from kornia.geometry.conversions import normalize_pixel_coordinates
  25. from kornia.utils import dataclass_to_dict, dict_to_dataclass
  26. from .backbones import SOLD2Net
  27. from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions
  28. urls: Dict[str, str] = {}
  29. urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"
  30. class SOLD2(Module):
  31. r"""Module, which detects and describe line segments in an image.
  32. This is based on the original code from the paper "SOLD²: Self-supervised
  33. Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details.
  34. Args:
  35. config: Dict specifying parameters. None will load the default parameters,
  36. which are tuned for images in the range 400~800 px.
  37. pretrained: If True, download and set pretrained weights to the model.
  38. Returns:
  39. The raw junction and line heatmaps, the semi-dense descriptor map,
  40. as well as the list of detected line segments (ij coordinates convention).
  41. Example:
  42. >>> images = torch.rand(2, 1, 64, 64)
  43. >>> sold2 = SOLD2()
  44. >>> outputs = sold2(images)
  45. >>> line_seg1 = outputs["line_segments"][0]
  46. >>> line_seg2 = outputs["line_segments"][1]
  47. >>> desc1 = outputs["dense_desc"][0]
  48. >>> desc2 = outputs["dense_desc"][1]
  49. >>> matches = sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])
  50. """
  51. def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None:
  52. if isinstance(config, dict):
  53. warnings.warn(
  54. "Usage of config as a plain dictionary is deprecated in favor of"
  55. " `kornia.features.sold2.structures.DetectorCfg`. The support of plain dictionaries"
  56. "as config will be removed in kornia v0.8.0 (December 2024).",
  57. category=DeprecationWarning,
  58. stacklevel=2,
  59. )
  60. config = dict_to_dataclass(config, DetectorCfg)
  61. super().__init__()
  62. # Initialize some parameters
  63. self.config = config if config is not None else DetectorCfg()
  64. self.config.use_descriptor = True # Only difference to SOLD2_detector DetectorCfg
  65. self.grid_size = self.config.grid_size
  66. self.junc_detect_thresh = self.config.detection_thresh
  67. self.max_num_junctions = self.config.max_num_junctions
  68. # Load the pre-trained model
  69. self.model = SOLD2Net(dataclass_to_dict(self.config))
  70. if pretrained:
  71. pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=torch.device("cpu"))
  72. state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"])
  73. self.model.load_state_dict(state_dict)
  74. self.eval()
  75. # Initialize the line detector
  76. self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg)
  77. # Initialize the line matcher
  78. self.line_matcher = WunschLineMatcher(self.config.line_matcher_cfg)
  79. def forward(self, img: Tensor) -> Dict[str, Any]:
  80. """Run forward.
  81. Args:
  82. img: batched images with shape :math:`(B, 1, H, W)`.
  83. Returns:
  84. line_segments: list of N line segments in each of the B images :math:`List[(N, 2, 2)]`.
  85. junction_heatmap: raw junction heatmap of shape :math:`(B, H, W)`.
  86. line_heatmap: raw line heatmap of shape :math:`(B, H, W)`.
  87. dense_desc: the semi-dense descriptor map of shape :math:`(B, 128, H/4, W/4)`.
  88. """
  89. KORNIA_CHECK_SHAPE(img, ["B", "1", "H", "W"])
  90. outputs = {}
  91. # Forward pass of the CNN backbone
  92. net_outputs = self.model(img)
  93. outputs["junction_heatmap"] = net_outputs["junctions"]
  94. outputs["line_heatmap"] = net_outputs["heatmap"]
  95. outputs["dense_desc"] = net_outputs["descriptors"]
  96. # Loop through all images
  97. lines = []
  98. for junc_prob, heatmap in zip(net_outputs["junctions"], net_outputs["heatmap"]):
  99. # Get the junctions
  100. junctions = prob_to_junctions(junc_prob, self.grid_size, self.junc_detect_thresh, self.max_num_junctions)
  101. # Run the line detector
  102. line_map, junctions, _ = self.line_detector.detect(junctions, heatmap)
  103. lines.append(line_map_to_segments(junctions, line_map))
  104. outputs["line_segments"] = lines
  105. return outputs
  106. def match(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
  107. """Find the best matches between two sets of line segments and their corresponding descriptors.
  108. Args:
  109. line_seg1: list of line segments in image 1, with shape [num_lines, 2, 2].
  110. line_seg2: list of line segments in image 2, with shape [num_lines, 2, 2].
  111. desc1: semi-dense descriptor map of image 1, with shape [1, 128, H/4, W/4].
  112. desc2: semi-dense descriptor map of image 2, with shape [1, 128, H/4, W/4].
  113. Returns:
  114. A np.array of size [num_lines1] indicating the index in line_seg2 of the matched line,
  115. for each line in line_seg1. -1 means that the line is not matched.
  116. """
  117. return self.line_matcher(line_seg1, line_seg2, desc1, desc2)
  118. def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
  119. del state_dict["w_junc"]
  120. del state_dict["w_heatmap"]
  121. del state_dict["w_desc"]
  122. state_dict["heatmap_decoder.conv_block_lst.2.0.weight"] = state_dict["heatmap_decoder.conv_block_lst.2.weight"]
  123. state_dict["heatmap_decoder.conv_block_lst.2.0.bias"] = state_dict["heatmap_decoder.conv_block_lst.2.bias"]
  124. del state_dict["heatmap_decoder.conv_block_lst.2.weight"]
  125. del state_dict["heatmap_decoder.conv_block_lst.2.bias"]
  126. return state_dict
  127. class WunschLineMatcher(Module):
  128. """Class matching two sets of line segments with the Needleman-Wunsch algorithm.
  129. TODO: move it later in kornia.feature.matching
  130. """
  131. def __init__(self, config: Optional[LineMatcherCfg] = None) -> None:
  132. super().__init__()
  133. # Initialize the parameters
  134. if config is None:
  135. config = LineMatcherCfg()
  136. self.config = config
  137. self.cross_check = self.config.cross_check
  138. self.num_samples = self.config.num_samples
  139. self.min_dist_pts = self.config.min_dist_pts
  140. self.top_k_candidates = self.config.top_k_candidates
  141. self.grid_size = self.config.grid_size
  142. self.line_score = self.config.line_score
  143. def forward(self, line_seg1: Tensor, line_seg2: Tensor, desc1: Tensor, desc2: Tensor) -> Tensor:
  144. """Find the best matches between two sets of line segments and their corresponding descriptors."""
  145. KORNIA_CHECK_SHAPE(line_seg1, ["N", "2", "2"])
  146. KORNIA_CHECK_SHAPE(line_seg2, ["N", "2", "2"])
  147. KORNIA_CHECK_SHAPE(desc1, ["B", "D", "H", "H"])
  148. KORNIA_CHECK_SHAPE(desc2, ["B", "D", "H", "H"])
  149. device = desc1.device
  150. img_size1 = (desc1.shape[2] * self.grid_size, desc1.shape[3] * self.grid_size)
  151. img_size2 = (desc2.shape[2] * self.grid_size, desc2.shape[3] * self.grid_size)
  152. # Default case when an image has no lines
  153. if len(line_seg1) == 0:
  154. return torch.empty(0, dtype=torch.int, device=device)
  155. if len(line_seg2) == 0:
  156. return -torch.ones(len(line_seg1), dtype=torch.int, device=device)
  157. # Sample points regularly along each line
  158. line_points1, valid_points1 = self.sample_line_points(line_seg1)
  159. line_points2, valid_points2 = self.sample_line_points(line_seg2)
  160. line_points1 = line_points1.reshape(-1, 2)
  161. line_points2 = line_points2.reshape(-1, 2)
  162. # Extract the descriptors for each point
  163. grid1 = keypoints_to_grid(line_points1, img_size1)
  164. grid2 = keypoints_to_grid(line_points2, img_size2)
  165. desc1 = F.normalize(F.grid_sample(desc1, grid1, align_corners=False)[0, :, :, 0], dim=0)
  166. desc2 = F.normalize(F.grid_sample(desc2, grid2, align_corners=False)[0, :, :, 0], dim=0)
  167. # Precompute the distance between line points for every pair of lines
  168. # Assign a score of -1 for invalid points
  169. scores = desc1.t() @ desc2
  170. scores[~valid_points1.flatten()] = -1
  171. scores[:, ~valid_points2.flatten()] = -1
  172. scores = scores.reshape(len(line_seg1), self.num_samples, len(line_seg2), self.num_samples)
  173. scores = scores.permute(0, 2, 1, 3)
  174. # scores.shape = (n_lines1, n_lines2, num_samples, num_samples)
  175. # Pre-filter the line candidates and find the best match for each line
  176. matches = self.filter_and_match_lines(scores)
  177. # [Optionally] filter matches with mutual nearest neighbor filtering
  178. if self.cross_check:
  179. matches2 = self.filter_and_match_lines(scores.permute(1, 0, 3, 2))
  180. mutual = matches2[matches] == torch.arange(len(line_seg1), device=device)
  181. matches[~mutual] = -1
  182. return matches
  183. def sample_line_points(self, line_seg: Tensor) -> Tuple[Tensor, Tensor]:
  184. """Regularly sample points along each line segments, with a minimal distance between each point.
  185. Pad the remaining points.
  186. Args:
  187. line_seg: an Nx2x2 Tensor.
  188. Returns:
  189. line_points: an N x num_samples x 2 Tensor.
  190. valid_points: a boolean N x num_samples Tensor.
  191. """
  192. _N, _, _ = line_seg.shape
  193. M = self.num_samples
  194. dev = line_seg.device
  195. lengths = torch.norm(line_seg[:, 0] - line_seg[:, 1], dim=1)
  196. num_pts = torch.clamp((lengths / self.min_dist_pts).floor().int(), min=2, max=M) # (N,)
  197. orig = line_seg[:, 0].unsqueeze(1)
  198. dirs = (line_seg[:, 1] - line_seg[:, 0]).unsqueeze(1)
  199. idx = torch.arange(M, device=dev).unsqueeze(0)
  200. denom = (num_pts - 1).unsqueeze(1)
  201. alpha = idx / denom
  202. pts = orig + dirs * alpha.unsqueeze(-1)
  203. valid = idx < num_pts.unsqueeze(1)
  204. pts = pts.masked_fill(~valid.unsqueeze(-1), 0.0)
  205. return pts, valid
  206. def filter_and_match_lines(self, scores: Tensor) -> Tensor:
  207. """Use scores to keep the top k best lines.
  208. Compute the Needleman- Wunsch algorithm on each candidate pairs, and keep the highest score.
  209. Args:
  210. scores: a (N, M, n, n) Tensor containing the pairwise scores
  211. of the elements to match.
  212. Returns:
  213. matches: a (N) Tensor containing the indices of the best match
  214. """
  215. KORNIA_CHECK_SHAPE(scores, ["M", "N", "n", "n"])
  216. # Pre-filter the pairs and keep the top k best candidate lines
  217. line_scores1 = scores.max(3)[0]
  218. valid_scores1 = line_scores1 != -1
  219. line_scores1 = (line_scores1 * valid_scores1).sum(2) / valid_scores1.sum(2)
  220. line_scores2 = scores.max(2)[0]
  221. valid_scores2 = line_scores2 != -1
  222. line_scores2 = (line_scores2 * valid_scores2).sum(2) / valid_scores2.sum(2)
  223. line_scores = (line_scores1 + line_scores2) / 2
  224. topk_lines = torch.argsort(line_scores, dim=1)[:, -self.top_k_candidates :]
  225. # topk_lines.shape = (n_lines1, top_k_candidates)
  226. top_scores = torch.take_along_dim(scores, topk_lines[:, :, None, None], dim=1)
  227. # Consider the reversed line segments as well
  228. top_scores = concatenate([top_scores, torch.flip(top_scores, dims=[-1])], 1)
  229. # Compute the line distance matrix with Needleman-Wunsch algo and
  230. # retrieve the closest line neighbor
  231. n_lines1, top2k, n, m = top_scores.shape
  232. top_scores = top_scores.reshape((n_lines1 * top2k, n, m))
  233. nw_scores = self.needleman_wunsch(top_scores)
  234. nw_scores = nw_scores.reshape(n_lines1, top2k)
  235. matches = torch.remainder(torch.argmax(nw_scores, dim=1), top2k // 2)
  236. matches = topk_lines[torch.arange(n_lines1), matches]
  237. return matches
  238. def needleman_wunsch(self, scores: Tensor) -> Tensor:
  239. """Batched implementation of the Needleman-Wunsch algorithm.
  240. The cost of the InDel operation is set to 0 by subtracting the gap
  241. penalty to the scores.
  242. Args:
  243. scores: a (B, N, M) Tensor containing the pairwise scores
  244. of the elements to match.
  245. """
  246. KORNIA_CHECK_SHAPE(scores, ["B", "N", "M"])
  247. # Recalibrate the scores to get a gap score of 0
  248. gap = 0.1
  249. B, N, M = scores.shape
  250. dp = torch.zeros(B, N + 1, M + 1, device=scores.device)
  251. S = scores - gap
  252. for k in range(2, N + M + 1):
  253. i_min = max(1, k - M)
  254. i_max = min(N, k - 1)
  255. i = torch.arange(i_min, i_max + 1, device=scores.device)
  256. j = k - i
  257. up = dp[:, i - 1, j]
  258. left = dp[:, i, j - 1]
  259. diag = dp[:, i - 1, j - 1] + S[:, i - 1, j - 1]
  260. dp[:, i, j] = torch.max(torch.max(up, left), diag)
  261. return dp[:, -1, -1]
  262. def keypoints_to_grid(keypoints: Tensor, img_size: Tuple[int, int]) -> Tensor:
  263. """Convert a list of keypoints into a grid in [-1, 1]² that can be used in torch.nn.functional.interpolate.
  264. Args:
  265. keypoints: a tensor [N, 2] of N keypoints (ij coordinates convention).
  266. img_size: the original image size (H, W)
  267. """
  268. KORNIA_CHECK_SHAPE(keypoints, ["N", "2"])
  269. n_points = len(keypoints)
  270. grid_points = normalize_pixel_coordinates(keypoints[:, [1, 0]], img_size[0], img_size[1])
  271. grid_points = grid_points.view(-1, n_points, 1, 2)
  272. return grid_points
  273. def batched_linspace(start: Tensor, end: Tensor, step: int, dim: int) -> Tensor:
  274. """Batch version of torch.normalize (similar to the numpy one)."""
  275. intervals = ((end - start) / (step - 1)).unsqueeze(dim)
  276. broadcast_size = [1] * len(intervals.shape)
  277. broadcast_size[dim] = step
  278. samples = torch.arange(step, dtype=torch.float, device=start.device).reshape(broadcast_size)
  279. samples = start.unsqueeze(dim) + samples * intervals
  280. return samples