loftr.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
from __future__ import annotations

import copy
from typing import Any, Optional

import torch

from kornia.core import Module, Tensor
from kornia.geometry import resize

from .backbone import build_backbone
from .loftr_module import FinePreprocess, LocalFeatureTransformer
from .utils.coarse_matching import CoarseMatching
from .utils.fine_matching import FineMatching
from .utils.position_encoding import PositionEncodingSine
  27. urls: dict[str, str] = {}
  28. urls["outdoor"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/loftr_outdoor.ckpt"
  29. urls["indoor_new"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/loftr_indoor_ds_new.ckpt"
  30. urls["indoor"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/loftr_indoor.ckpt"
  31. # Comments: the config below is the one corresponding to the pretrained models
  32. # Some do not change there anything, unless you want to retrain it.
  33. default_cfg = {
  34. "backbone_type": "ResNetFPN",
  35. "resolution": (8, 2),
  36. "fine_window_size": 5,
  37. "fine_concat_coarse_feat": True,
  38. "resnetfpn": {"initial_dim": 128, "block_dims": [128, 196, 256]},
  39. "coarse": {
  40. "d_model": 256,
  41. "d_ffn": 256,
  42. "nhead": 8,
  43. "layer_names": ["self", "cross", "self", "cross", "self", "cross", "self", "cross"],
  44. "attention": "linear",
  45. "temp_bug_fix": False,
  46. },
  47. "match_coarse": {
  48. "thr": 0.2,
  49. "border_rm": 2,
  50. "match_type": "dual_softmax",
  51. "dsmax_temperature": 0.1,
  52. "skh_iters": 3,
  53. "skh_init_bin_score": 1.0,
  54. "skh_prefilter": True,
  55. "train_coarse_percent": 0.4,
  56. "train_pad_num_gt_min": 200,
  57. },
  58. "fine": {"d_model": 128, "d_ffn": 128, "nhead": 8, "layer_names": ["self", "cross"], "attention": "linear"},
  59. }
  60. class LoFTR(Module):
  61. r"""Module, which finds correspondences between two images.
  62. This is based on the original code from paper "LoFTR: Detector-Free Local
  63. Feature Matching with Transformers". See :cite:`LoFTR2021` for more details.
  64. If the distance matrix dm is not provided, :py:func:`torch.cdist` is used.
  65. Args:
  66. config: Dict with initialization parameters. Do not pass it, unless you know what you are doing`.
  67. pretrained: Download and set pretrained weights to the model. Options: 'outdoor', 'indoor'.
  68. 'outdoor' is trained on the MegaDepth dataset and 'indoor'
  69. on the ScanNet.
  70. Returns:
  71. Dictionary with image correspondences and confidence scores.
  72. Example:
  73. >>> img1 = torch.rand(1, 1, 320, 200)
  74. >>> img2 = torch.rand(1, 1, 128, 128)
  75. >>> input = {"image0": img1, "image1": img2}
  76. >>> loftr = LoFTR('outdoor')
  77. >>> out = loftr(input)
  78. """
  79. def __init__(self, pretrained: Optional[str] = "outdoor", config: dict[str, Any] = default_cfg) -> None:
  80. super().__init__()
  81. # Misc
  82. self.config = config
  83. if pretrained == "indoor_new":
  84. self.config["coarse"]["temp_bug_fix"] = True
  85. # Modules
  86. self.backbone = build_backbone(config)
  87. self.pos_encoding = PositionEncodingSine(
  88. config["coarse"]["d_model"], temp_bug_fix=config["coarse"]["temp_bug_fix"]
  89. )
  90. self.loftr_coarse = LocalFeatureTransformer(config["coarse"])
  91. self.coarse_matching = CoarseMatching(config["match_coarse"])
  92. self.fine_preprocess = FinePreprocess(config)
  93. self.loftr_fine = LocalFeatureTransformer(config["fine"])
  94. self.fine_matching = FineMatching()
  95. self.pretrained = pretrained
  96. if pretrained is not None:
  97. if pretrained not in urls.keys():
  98. raise ValueError(f"pretrained should be None or one of {urls.keys()}")
  99. pretrained_dict = torch.hub.load_state_dict_from_url(urls[pretrained], map_location=torch.device("cpu"))
  100. self.load_state_dict(pretrained_dict["state_dict"])
  101. self.eval()
  102. def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
  103. """Run forward.
  104. Args:
  105. data: dictionary containing the input data in the following format:
  106. Keyword Args:
  107. image0: left image with shape :math:`(N, 1, H1, W1)`.
  108. image1: right image with shape :math:`(N, 1, H2, W2)`.
  109. mask0 (optional): left image mask. '0' indicates a padded position :math:`(N, H1, W1)`.
  110. mask1 (optional): right image mask. '0' indicates a padded position :math:`(N, H2, W2)`.
  111. Returns:
  112. - ``keypoints0``, matching keypoints from image0 :math:`(NC, 2)`.
  113. - ``keypoints1``, matching keypoints from image1 :math:`(NC, 2)`.
  114. - ``confidence``, confidence score [0, 1] :math:`(NC)`.
  115. - ``batch_indexes``, batch indexes for the keypoints and lafs :math:`(NC)`.
  116. """
  117. # 1. Local Feature CNN
  118. _data: dict[str, Tensor | int | torch.Size] = {
  119. "bs": data["image0"].size(0),
  120. "hw0_i": data["image0"].shape[2:],
  121. "hw1_i": data["image1"].shape[2:],
  122. }
  123. if _data["hw0_i"] == _data["hw1_i"]: # faster & better BN convergence
  124. feats_c, feats_f = self.backbone(torch.cat([data["image0"], data["image1"]], dim=0))
  125. (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(_data["bs"]), feats_f.split(_data["bs"])
  126. else: # handle different input shapes
  127. (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data["image0"]), self.backbone(data["image1"])
  128. _data.update(
  129. {
  130. "hw0_c": feat_c0.shape[2:],
  131. "hw1_c": feat_c1.shape[2:],
  132. "hw0_f": feat_f0.shape[2:],
  133. "hw1_f": feat_f1.shape[2:],
  134. }
  135. )
  136. # 2. coarse-level loftr module
  137. # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
  138. # feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
  139. # feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
  140. feat_c0 = self.pos_encoding(feat_c0).permute(0, 2, 3, 1)
  141. n, _h, _w, c = feat_c0.shape
  142. feat_c0 = feat_c0.reshape(n, -1, c)
  143. feat_c1 = self.pos_encoding(feat_c1).permute(0, 2, 3, 1)
  144. n1, _h1, _w1, c1 = feat_c1.shape
  145. feat_c1 = feat_c1.reshape(n1, -1, c1)
  146. mask_c0 = mask_c1 = None # mask is useful in training
  147. if "mask0" in data:
  148. mask_c0 = resize(data["mask0"], _data["hw0_c"], interpolation="nearest").flatten(-2)
  149. if "mask1" in data:
  150. mask_c1 = resize(data["mask1"], _data["hw1_c"], interpolation="nearest").flatten(-2)
  151. feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
  152. # 3. match coarse-level
  153. self.coarse_matching(feat_c0, feat_c1, _data, mask_c0=mask_c0, mask_c1=mask_c1)
  154. # 4. fine-level refinement
  155. feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, _data)
  156. if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
  157. feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
  158. # 5. match fine-level
  159. self.fine_matching(feat_f0_unfold, feat_f1_unfold, _data)
  160. rename_keys: dict[str, str] = {
  161. "mkpts0_f": "keypoints0",
  162. "mkpts1_f": "keypoints1",
  163. "mconf": "confidence",
  164. "b_ids": "batch_indexes",
  165. }
  166. out: dict[str, Tensor] = {}
  167. for k, v in rename_keys.items():
  168. _d = _data[k]
  169. if isinstance(_d, Tensor):
  170. out[v] = _d
  171. else:
  172. raise TypeError(f"Expected Tensor for item `{k}`. Gotcha {type(_d)}")
  173. return out
  174. def load_state_dict(self, state_dict: dict[str, Any], *args: Any, **kwargs: Any) -> Any: # type: ignore[override]
  175. for k in list(state_dict.keys()):
  176. if k.startswith("matcher."):
  177. state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k)
  178. return super().load_state_dict(state_dict, *args, **kwargs)