| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Any, ClassVar, Dict, List, Optional, Tuple
- import torch
- from kornia.core import Module, Tensor, concatenate
- from kornia.core.check import KORNIA_CHECK_DM_DESC, KORNIA_CHECK_SHAPE
- from kornia.feature.laf import get_laf_center
- from kornia.feature.steerers import DiscreteSteerer
- from kornia.utils.helpers import is_mps_tensor_safe
- from .adalam import get_adalam_default_config, match_adalam
def _cdist(d1: Tensor, d2: Tensor) -> Tensor:
    r"""Pairwise Euclidean distance matrix with a manual fallback to `torch.cdist` (written for M1).

    When `is_mps_tensor_safe` is false for both inputs, `torch.cdist` is used directly.
    Otherwise the distances are expanded by hand from
    :math:`\|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2 a \cdot b`.

    Args:
        d1: Tensor of shape :math:`(B1, D)`.
        d2: Tensor of shape :math:`(B2, D)`.

    Returns:
        Distance matrix of shape :math:`(B1, B2)`.
    """
    needs_manual_path = is_mps_tensor_safe(d1) or is_mps_tensor_safe(d2)
    if not needs_manual_path:
        return torch.cdist(d1, d2)

    sq_norms1 = (d1**2).sum(dim=1, keepdim=True)  # (B1, 1)
    sq_norms2 = (d2**2).sum(dim=1, keepdim=True)  # (B2, 1)
    cross_term = (2.0 * d1) @ d2.t()  # (B1, B2)
    # (B1, 1) + (1, B2) broadcasts to the full (B1, B2) grid.
    sq_dists = sq_norms1 + sq_norms2.t() - cross_term
    # Rounding can push tiny squared distances below zero; clamp before the root.
    return sq_dists.clamp(min=0.0).sqrt()
- def _get_default_fginn_params() -> Dict[str, Any]:
- config = {"th": 0.85, "mutual": False, "spatial_th": 10.0}
- return config
def _get_lazy_distance_matrix(desc1: Tensor, desc2: Tensor, dm_: Optional[Tensor] = None) -> Tensor:
    """Return a validated user-supplied distance matrix, or compute the L2 one on demand.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm_: Optional tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Returns:
        ``dm_`` itself (after ``KORNIA_CHECK_DM_DESC``) when it is given, otherwise the
        matrix computed by ``_cdist(desc1, desc2)``.
    """
    if dm_ is not None:
        # A precomputed matrix is only checked against the descriptors, never recomputed.
        KORNIA_CHECK_DM_DESC(desc1, desc2, dm_)
        return dm_
    return _cdist(desc1, desc2)
- def _no_match(dm: Tensor) -> Tuple[Tensor, Tensor]:
- """Output empty tensors.
- Returns:
- - Descriptor distance of matching descriptors, shape of :math:`(0, 1)`.
- - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(0, 2)`.
- """
- dists = torch.empty(0, 1, device=dm.device, dtype=dm.dtype)
- idxs = torch.empty(0, 2, device=dm.device, dtype=torch.long)
- return dists, idxs
def match_nn(desc1: Tensor, desc2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    r"""Match every descriptor in desc1 to its nearest neighbor in desc2.

    If the distance matrix dm is not provided, it is computed from the descriptors
    (:py:func:`torch.cdist`, or its manual fallback).

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Returns:
        - Descriptor distance of matching descriptors, shape of :math:`(B1, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B1, 2)`.
          Both outputs are empty (:math:`(0, 1)` and :math:`(0, 2)`) if either input has no rows.
    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    if 0 in (len(desc1), len(desc2)):
        return _no_match(desc1)

    pairwise = _get_lazy_distance_matrix(desc1, desc2, dm)
    # Row-wise minimum: best desc2 candidate for each desc1 row.
    nearest_dists, nearest_in_2 = torch.min(pairwise, dim=1)
    rows_in_1 = torch.arange(0, nearest_in_2.size(0), device=nearest_in_2.device)
    index_pairs = concatenate([rows_in_1.unsqueeze(1), nearest_in_2.unsqueeze(1)], 1)
    return nearest_dists.view(-1, 1), index_pairs.view(-1, 2)
def match_mnn(desc1: Tensor, desc2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Keep only the pairs that are nearest neighbors of each other (mutual nearest neighbors).

    If the distance matrix dm is not provided, it is computed from the descriptors
    (:py:func:`torch.cdist`, or its manual fallback).

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B3, 2)`,
          where 0 <= B3 <= min(B1, B2).
    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    if 0 in (len(desc1), len(desc2)):
        return _no_match(desc1)

    pairwise = _get_lazy_distance_matrix(desc1, desc2, dm)
    num1, num2 = pairwise.size(0), pairwise.size(1)
    smaller = min(num1, num2)

    row_min, nn_in_2 = torch.min(pairwise, dim=1)  # best desc2 for every desc1 row
    col_min, nn_in_1 = torch.min(pairwise, dim=0)  # best desc1 for every desc2 column
    span = torch.arange(smaller, device=pairwise.device)

    if num1 <= num2:
        # Walk desc1 -> desc2 -> desc1 and keep rows that land back on themselves.
        is_mutual = span == nn_in_1[nn_in_2][:smaller]
        candidate_pairs = concatenate([span.unsqueeze(1), nn_in_2.unsqueeze(1)], 1)
        kept_dists = row_min[is_mutual]
    else:
        # Fewer desc2 entries: walk desc2 -> desc1 -> desc2 instead.
        is_mutual = span == nn_in_2[nn_in_1][:smaller]
        candidate_pairs = concatenate([nn_in_1.unsqueeze(1), span.unsqueeze(1)], 1)
        kept_dists = col_min[is_mutual]

    kept_pairs = candidate_pairs[is_mutual]
    return kept_dists.view(-1, 1), kept_pairs.view(-1, 2)
def match_snn(desc1: Tensor, desc2: Tensor, th: float = 0.8, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Nearest-neighbor matching filtered by the first-to-second nearest neighbor ratio test.

    A desc1 row is kept only if (distance to 1st neighbor) / (distance to 2nd neighbor) <= th.
    If the distance matrix dm is not provided, it is computed from the descriptors
    (:py:func:`torch.cdist`, or its manual fallback).

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        th: distance ratio threshold.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - Quality value of the kept matches, shape of :math:`(B3, 1)`. Note that the value
          returned is the first-to-second distance ratio, not the raw descriptor distance.
        - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
          where 0 <= B3 <= B1.
    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    if desc2.shape[0] < 2:
        # A second neighbor does not exist, so the ratio test cannot be evaluated.
        return _no_match(desc1)

    pairwise = _get_lazy_distance_matrix(desc1, desc2, dm)
    two_smallest, two_nearest = torch.topk(pairwise, 2, dim=1, largest=False)
    first_to_second = two_smallest[:, 0] / two_smallest[:, 1]
    passes = first_to_second <= th

    kept_ratios = first_to_second[passes]
    if kept_ratios.numel() == 0:
        return _no_match(pairwise)

    rows_in_1 = torch.arange(0, two_nearest.size(0), device=pairwise.device)[passes]
    best_in_2 = two_nearest[:, 0][passes]
    index_pairs = concatenate([rows_in_1.unsqueeze(1), best_in_2.unsqueeze(1)], 1)
    return kept_ratios.view(-1, 1), index_pairs.view(-1, 2)
def match_smnn(desc1: Tensor, desc2: Tensor, th: float = 0.95, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Find mutual nearest neighbors in desc2 for each vector in desc1, with a two-way ratio test.

    The ratio test (first to second nearest neighbor distance <= th) is run in both
    directions, desc1 -> desc2 and desc2 -> desc1, and only pairs that survive both are kept.
    If the distance matrix dm is not provided, it is computed from the descriptors
    (:py:func:`torch.cdist`, or its manual fallback).

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        th: distance ratio threshold.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - Quality value of the kept matches, shape of :math:`(B3, 1)`: the larger (worse) of the
          two directional first-to-second distance ratios.
        - Long tensor indexes of matching descriptors in desc1 and desc2,
          shape of :math:`(B3, 2)` where 0 <= B3 <= B1. Rows are ordered by increasing desc1 index.
    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    if (desc1.shape[0] < 2) or (desc2.shape[0] < 2):
        # The ratio test needs a second neighbor in both directions.
        return _no_match(desc1)

    distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
    # pairs_12 rows are (idx in desc1, idx in desc2); pairs_21 rows are (idx in desc2, idx in desc1).
    ratios_12, pairs_12 = match_snn(desc1, desc2, th, distance_matrix)
    ratios_21, pairs_21 = match_snn(desc2, desc1, th, distance_matrix.t())
    if len(ratios_12) == 0 or len(ratios_21) == 0:
        return _no_match(distance_matrix)

    # Exact integer lookup instead of a dense (N1 x N2) float matrix between index pairs:
    # for each desc2 index, remember which row of pairs_21 holds its match (-1 if it has none).
    slot_of_desc2 = torch.full((distance_matrix.size(1),), -1, dtype=torch.long, device=pairs_21.device)
    slot_of_desc2[pairs_21[:, 0]] = torch.arange(pairs_21.size(0), device=pairs_21.device)

    slots = slot_of_desc2[pairs_12[:, 1]]
    # clamp only keeps the gather in-bounds for the -1 entries; `slots >= 0` masks them out.
    back_in_1 = pairs_21[slots.clamp(min=0), 1]
    is_mutual = (slots >= 0) & (back_in_1 == pairs_12[:, 0])

    matches_idxs = pairs_12[is_mutual]
    match_dists = torch.max(ratios_12[is_mutual], ratios_21[slots[is_mutual]])
    return match_dists.view(-1, 1), matches_idxs.view(-1, 2)
def match_fginn(
    desc1: Tensor,
    desc2: Tensor,
    lafs1: Tensor,
    lafs2: Tensor,
    th: float = 0.8,
    spatial_th: float = 10.0,
    mutual: bool = False,
    dm: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    """Find nearest neighbors in desc2 for each vector in desc1.

    The method satisfies first to second nearest neighbor distance <= th,
    and assures 2nd nearest neighbor is geometrically inconsistent with the 1st one
    (see :cite:`MODS2015` for more details).

    Only the closest ``min(10, B2)`` candidates of every query are inspected. Candidates whose
    center lies closer than ``spatial_th`` to the query's own nearest neighbor are excluded from
    the role of "second nearest neighbor" by adding a large constant to their distance.

    If the distance matrix dm is not provided, it is computed from the descriptors
    (:py:func:`torch.cdist`, or its manual fallback).

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`. Not used by the computation; kept for
          interface symmetry.
        lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`.
        th: distance ratio threshold.
        spatial_th: minimal distance in pixels to 2nd nearest neighbor.
        mutual: also perform mutual nearest neighbor check
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - Quality value of the kept matches (the first-to-second distance ratio),
          shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
          where 0 <= B3 <= B1.
    """
    KORNIA_CHECK_SHAPE(desc1, ["B", "DIM"])
    KORNIA_CHECK_SHAPE(desc2, ["B", "DIM"])
    BIG_NUMBER = 1000000.0

    distance_matrix = _get_lazy_distance_matrix(desc1, desc2, dm)
    dtype = distance_matrix.dtype

    if desc2.shape[0] < 2:  # We cannot perform snn check, so output empty matches
        return _no_match(distance_matrix)

    num_candidates = max(2, min(10, desc2.shape[0]))
    vals_cand, idxs_in_2 = torch.topk(distance_matrix, num_candidates, dim=1, largest=False)
    vals = vals_cand[:, 0]

    xy2 = get_laf_center(lafs2).view(-1, 2)
    candidates_xy = xy2[idxs_in_2]  # (B1, num_candidates, 2)
    # Spatial distance of every candidate to the nearest neighbor of the SAME query.
    # The slice must be taken along the candidate axis (dim 1); slicing dim 0 would compare
    # against the candidates of query 0 instead.
    kdist = torch.norm(candidates_xy - candidates_xy[:, 0:1], p=2, dim=2)
    # Push spatially-close candidates out of the competition for "second nearest".
    fginn_vals = vals_cand[:, 1:] + (kdist[:, 1:] < spatial_th).to(dtype) * BIG_NUMBER
    vals_2nd = fginn_vals.min(dim=1)[0]

    idxs_in_2 = idxs_in_2[:, 0]
    ratio = vals / vals_2nd
    mask = ratio <= th
    match_dists = ratio[mask]
    if len(match_dists) == 0:
        return _no_match(distance_matrix)

    idxs_in1 = torch.arange(0, idxs_in_2.size(0), device=distance_matrix.device)[mask]
    idxs_in_2 = idxs_in_2[mask]
    matches_idxs = concatenate([idxs_in1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
    match_dists, matches_idxs = match_dists.view(-1, 1), matches_idxs.view(-1, 2)

    if not mutual:  # returning 1-way matches
        return match_dists, matches_idxs

    # Mutual check: the matched desc2 entry must have this desc1 row as its own nearest neighbor.
    _, idxs_in_1_mut = torch.min(distance_matrix, dim=0)
    good_mask = matches_idxs[:, 0] == idxs_in_1_mut[matches_idxs[:, 1]]
    return match_dists[good_mask], matches_idxs[good_mask]
class DescriptorMatcher(Module):
    """Module wrapper around the descriptor matching functions.

    See :func:`~kornia.feature.match_nn`, :func:`~kornia.feature.match_snn`,
    :func:`~kornia.feature.match_mnn` or :func:`~kornia.feature.match_smnn` for more details.

    Args:
        match_mode: type of matching, can be `nn`, `snn`, `mnn`, `smnn` (case-insensitive).
        th: threshold on distance ratio, or other quality measure. Only forwarded to the
          `snn` and `smnn` modes.

    Raises:
        NotImplementedError: if ``match_mode`` is not one of the known modes.
    """

    def __init__(self, match_mode: str = "snn", th: float = 0.8) -> None:
        super().__init__()
        requested_mode: str = match_mode.lower()
        self.known_modes = ["nn", "mnn", "snn", "smnn"]
        if requested_mode not in self.known_modes:
            raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
        self.match_mode = requested_mode
        self.th = th

    def forward(self, desc1: Tensor, desc2: Tensor) -> Tuple[Tensor, Tensor]:
        """Match two sets of descriptors with the configured mode.

        Args:
            desc1: Batch of descriptors of a shape :math:`(B1, D)`.
            desc2: Batch of descriptors of a shape :math:`(B2, D)`.

        Returns:
            - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2,
                shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
        """
        mode = self.match_mode
        if mode == "nn":
            return match_nn(desc1, desc2)
        if mode == "mnn":
            return match_mnn(desc1, desc2)
        if mode == "snn":
            return match_snn(desc1, desc2, self.th)
        if mode == "smnn":
            return match_smnn(desc1, desc2, self.th)
        # Only reachable if match_mode was changed after construction.
        raise NotImplementedError
class DescriptorMatcherWithSteerer(Module):
    """Matching that is invariant under rotations, using Steerers.

    Args:
        steerer: An instance of :func:`kornia.feature.steerers.DiscreteSteerer`.
        steerer_order: order of discretisation of rotation angles, e.g. 4 leads to quarter rotations.
        steer_mode: can be `global`, `local`.
            `global` means that we output matches from the global rotation with most matches.
            `local` means that we output matches from a distance matrix
            where the distance between each descriptor pair is the minimal over rotations.
        match_mode: type of matching, can be `nn`, `snn`, `mnn`, `smnn`.
            WARNING: using steer_mode `global` with match_mode `nn` will lead to bad results
            since `nn` doesn't generate different amount of matches depending on goodness of fit.
        th: threshold on distance ratio, or other quality measure.

    Raises:
        NotImplementedError: if ``steer_mode`` or ``match_mode`` is not one of the known values.

    Example:
        >>> import kornia as K
        >>> import kornia.feature as KF
        >>> device = K.utils.get_cuda_or_mps_device_if_available()
        >>> img1 = torch.randn([1, 3, 768, 768], device=device)
        >>> img2 = torch.randn([1, 3, 768, 768], device=device)
        >>> dedode = KF.DeDoDe.from_pretrained(detector_weights="L-C4-v2", descriptor_weights="B-SO2").to(device)
        >>> steerer_order = 8  # discretisation order of rotation angles
        >>> steerer = KF.steerers.DiscreteSteerer.create_dedode_default(
        ...     generator_type="SO2", steerer_order=steerer_order
        ... )
        >>> steerer = steerer.to(device)
        >>> matcher = KF.matching.DescriptorMatcherWithSteerer(
        ...     steerer=steerer, steerer_order=steerer_order, steer_mode="global", match_mode="smnn", th=0.98
        ... )
        >>> with torch.inference_mode():
        ...     kps1, scores1, descs1 = dedode(img1, n=20_000)
        ...     kps2, scores2, descs2 = dedode(img2, n=20_000)
        ...     kps1, kps2, descs1, descs2 = kps1[0], kps2[0], descs1[0], descs2[0]
        ...     dists, idxs, num_rot = matcher(
        ...         descs1, descs2, normalize=True, subset_size=1000,
        ...     )
        >>> # print(f"{idxs.shape[0]} tentative matches with steered DeDoDe")
        >>> # print(f"at rotation of {num_rot * 360 / steerer_order} degrees")
    """

    def __init__(
        self,
        steerer: DiscreteSteerer,
        steerer_order: int,
        steer_mode: str = "global",
        match_mode: str = "snn",
        th: float = 0.8,
    ) -> None:
        super().__init__()
        self.steerer = steerer
        self.steerer_order = steerer_order

        requested_steer_mode: str = steer_mode.lower()
        self.known_steer_modes = ["global", "local"]
        if requested_steer_mode not in self.known_steer_modes:
            raise NotImplementedError(f"{steer_mode} is not supported. Try one of {self.known_steer_modes}")
        self.steer_mode = requested_steer_mode

        requested_match_mode: str = match_mode.lower()
        self.known_modes = ["nn", "mnn", "snn", "smnn"]
        if requested_match_mode not in self.known_modes:
            raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
        self.match_mode = requested_match_mode
        self.th = th

    def matching_function(self, d1: Tensor, d2: Tensor, dm: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """Dispatch to the matching routine selected by ``self.match_mode``.

        Args:
            d1: Batch of descriptors of a shape :math:`(B1, D)`.
            d2: Batch of descriptors of a shape :math:`(B2, D)`.
            dm: Optional precomputed distance matrix, shape of :math:`(B1, B2)`.

        Returns:
            The ``(distances, index pairs)`` tuple produced by the selected ``match_*`` function.
        """
        mode = self.match_mode
        if mode == "nn":
            return match_nn(d1, d2, dm=dm)
        if mode == "mnn":
            return match_mnn(d1, d2, dm=dm)
        if mode == "snn":
            return match_snn(d1, d2, self.th, dm=dm)
        if mode == "smnn":
            return match_smnn(d1, d2, self.th, dm=dm)
        raise NotImplementedError

    def forward(
        self,
        desc1: Tensor,
        desc2: Tensor,
        normalize: bool = False,
        subset_size: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor, Optional[int]]:
        """Run rotation-aware matching.

        Args:
            desc1: Batch of descriptors of a shape :math:`(B1, D)`.
            desc2: Batch of descriptors of a shape :math:`(B2, D)`.
            normalize: bool to decide whether to normalize descriptors to unit norm.
            subset_size: If set, the subset size to use for determining optimal
                number of rotations. Smaller subset size leads to faster but less
                accurate matching. Only used when `self.steer_mode` is `"global"`.

        Returns:
            - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2,
                shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
            - Number of global rotations from desc1 to desc2, in terms of `self.steerer_order`
                (will be `None` if `self.steer_mode` is `local`).
        """
        if normalize:
            desc1 = torch.nn.functional.normalize(desc1, dim=-1)
            desc2 = torch.nn.functional.normalize(desc2, dim=-1)

        if self.steer_mode == "local":
            # Element-wise minimum of the distance matrices over all steered versions of desc1.
            min_dm = _cdist(desc1, desc2)
            for _ in range(1, self.steerer_order):
                desc1 = self.steerer.steer_descriptions(desc1, normalize=normalize)
                min_dm = torch.minimum(min_dm, _cdist(desc1, desc2))
            local_dist, local_idx = self.matching_function(desc1, desc2, min_dm)
            return local_dist, local_idx, None

        if self.steer_mode != "global":
            raise NotImplementedError

        if subset_size is not None:
            # Estimate the rotation on random subsets, then match the full sets once.
            picked1 = torch.randperm(desc1.shape[0])[:subset_size]
            picked2 = torch.randperm(desc2.shape[0])[:subset_size]
            _, _, estimated_rot = self(
                desc1[picked1],
                desc2[picked2],
                normalize=normalize,
            )
            desc1 = self.steerer.steer_descriptions(
                desc1,
                steerer_power=estimated_rot,
                normalize=normalize,
            )
            subset_dist, subset_idx = self.matching_function(desc1, desc2, None)
            return subset_dist, subset_idx, estimated_rot

        # Exhaustive search: steer desc1 once per iteration and keep the attempt
        # that yields strictly more matches than the best one so far.
        best_dist, best_idx = self.matching_function(desc1, desc2, None)
        best_rot = 0
        for rot in range(1, self.steerer_order):
            desc1 = self.steerer.steer_descriptions(desc1, normalize=normalize)
            trial_dist, trial_idx = self.matching_function(desc1, desc2, None)
            if trial_idx.shape[0] > best_idx.shape[0]:
                best_dist, best_idx, best_rot = trial_dist, trial_idx, rot
        return best_dist, best_idx, best_rot
class GeometryAwareDescriptorMatcher(Module):
    """Module wrapper around the matching functions that also use keypoint geometry (LAFs).

    Args:
        match_mode: type of matching, can be `fginn` or `adalam` (case-insensitive).
        params: optional dictionary of overrides merged on top of the default configuration
          of the selected mode. For `fginn` the defaults provide the keys ``"th"``,
          ``"spatial_th"`` and ``"mutual"``.

    Raises:
        NotImplementedError: if ``match_mode`` is not one of ``known_modes``.
    """

    known_modes: ClassVar[List[str]] = ["fginn", "adalam"]

    def __init__(self, match_mode: str = "fginn", params: Optional[Dict[str, Tensor]] = None) -> None:
        super().__init__()
        requested_mode: str = match_mode.lower()
        if requested_mode not in self.known_modes:
            raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
        self.match_mode = requested_mode
        self.params = params or {}

    def forward(self, desc1: Tensor, desc2: Tensor, lafs1: Tensor, lafs2: Tensor) -> Tuple[Tensor, Tensor]:
        """Match two sets of descriptors, taking their local affine frames into account.

        Args:
            desc1: Batch of descriptors of a shape :math:`(B1, D)`.
            desc2: Batch of descriptors of a shape :math:`(B2, D)`.
            lafs1: LAFs of a shape :math:`(1, B1, 2, 3)`.
            lafs2: LAFs of a shape :math:`(1, B2, 2, 3)`.

        Returns:
            - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2,
                shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.
        """
        if self.match_mode == "fginn":
            # Defaults first, user overrides on top.
            fginn_cfg = _get_default_fginn_params()
            fginn_cfg.update(self.params)
            return match_fginn(
                desc1, desc2, lafs1, lafs2, fginn_cfg["th"], fginn_cfg["spatial_th"], fginn_cfg["mutual"]
            )
        if self.match_mode == "adalam":
            adalam_cfg = get_adalam_default_config()
            adalam_cfg.update(self.params)  # type: ignore[typeddict-item]
            return match_adalam(desc1, desc2, lafs1, lafs2, config=adalam_cfg)
        raise NotImplementedError
|