| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- from __future__ import annotations
- import sys
- from pathlib import Path
- import numpy as np
- import torch
# Directory two levels above this file; used as the anchor for repo-relative paths.
_REPOSITORY_ROOT_DIRECTORY_PATH = Path(__file__).resolve().parent.parent

# Directory that gets prepended to ``sys.path`` so that ``import lightglue`` works.
_LIGHTGLUE_REPOSITORY_PACKAGE_PARENT_DIRECTORY_PATH = (
    _REPOSITORY_ROOT_DIRECTORY_PATH.joinpath("python", "LightGlue")
)

# Keypoint cap used when the caller passes ``None`` for ``max_num_keypoints``.
_DEFAULT_SUPERPOINT_MAX_NUM_KEYPOINTS_INTEGER = 2048
def _ensure_lightglue_python_package_is_importable() -> None:
    """Put the LightGlue package parent directory at the front of ``sys.path``.

    The call is a no-op when the directory string is already listed, so
    calling it repeatedly never adds duplicate entries.
    """
    lightglue_parent_directory = str(
        _LIGHTGLUE_REPOSITORY_PACKAGE_PARENT_DIRECTORY_PATH
    )
    if lightglue_parent_directory in sys.path:
        return
    # Index 0 so this copy takes precedence over later ``sys.path`` entries.
    sys.path.insert(0, lightglue_parent_directory)
def _resolve_existing_image_paths(
    template_image_path_or_bgr_numpy_array,
    large_image_path_or_bgr_numpy_array,
):
    """Return both image sources unchanged as a ``(template, large)`` tuple.

    NOTE(review): despite the name, no path resolution or existence check is
    performed here -- the arguments are passed straight through. Confirm
    whether that is intentional.
    """
    image_source_pair = (
        template_image_path_or_bgr_numpy_array,
        large_image_path_or_bgr_numpy_array,
    )
    return image_source_pair
def _choose_torch_device(device: torch.device | None) -> torch.device:
    """Return ``device`` when given; otherwise CUDA if available, else CPU."""
    if device is None:
        fallback_device_name = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(fallback_device_name)
    return device
def _effective_superpoint_keypoint_cap(
    max_num_keypoints: int | None,
) -> int:
    """Return the keypoint cap to use, substituting the module default for ``None``.

    Any non-``None`` value is coerced with ``int()``.
    """
    if max_num_keypoints is not None:
        return int(max_num_keypoints)
    return _DEFAULT_SUPERPOINT_MAX_NUM_KEYPOINTS_INTEGER
def _get_superpoint_extractor(device: torch.device, max_num_keypoints: int):
    """Build a SuperPoint extractor, switch it to eval mode and move it to ``device``.

    A new extractor is constructed on every call; nothing is cached here.
    """
    _ensure_lightglue_python_package_is_importable()
    # Imported lazily: the package only becomes importable after the call above.
    from lightglue import SuperPoint

    superpoint_model = SuperPoint(max_num_keypoints=max_num_keypoints)
    superpoint_model = superpoint_model.eval()
    return superpoint_model.to(device)
def extract_superpoint_features_pair_rbd_on_device(
    extractor,
    image0: torch.Tensor,
    image1: torch.Tensor,
    device: torch.device,
):
    """Extract features for two images and return them post-processed on ``device``.

    Each image is run through ``extractor.extract``; each result is then passed
    through ``rbd`` and ``batch_to_device`` from ``lightglue.utils``.
    NOTE(review): ``rbd`` presumably strips the batch dimension -- confirm
    against the LightGlue utilities.

    Returns:
        ``(feats0, feats1)`` for ``image0`` and ``image1`` respectively.
    """
    _ensure_lightglue_python_package_is_importable()
    # Imported lazily: the package only becomes importable after the call above.
    from lightglue.utils import batch_to_device, rbd

    raw_features0 = extractor.extract(image0)
    raw_features1 = extractor.extract(image1)

    # ``batch_to_device`` is handed the device as a string.
    target_device_name = str(device)
    feats0, feats1 = (
        batch_to_device(rbd(raw_features), target_device_name)
        for raw_features in (raw_features0, raw_features1)
    )
    return feats0, feats1
def _numpy_bgr_uint8_array_to_rgb_torch_chw_float_zero_to_one(
    bgr_uint8_numpy_array: np.ndarray,
) -> torch.Tensor:
    """Convert a BGR image array to a tensor via LightGlue's ``numpy_image_to_torch``.

    For a 3-D array the last axis is reversed (BGR -> RGB) before conversion.
    Arrays of any other rank are forwarded unchanged: reversing the last axis
    of a 2-D grayscale image would flip it horizontally instead of swapping
    channels, which silently mirrors every keypoint coordinate.

    Args:
        bgr_uint8_numpy_array: image array; presumably ``(H, W, 3)`` uint8 in
            BGR order as the name states -- TODO confirm with callers.

    Returns:
        Whatever ``numpy_image_to_torch`` returns for the (possibly reordered)
        array. NOTE(review): the function name implies a CHW float tensor
        scaled to [0, 1]; confirm against ``lightglue.utils``.
    """
    _ensure_lightglue_python_package_is_importable()
    # Imported lazily: the package only becomes importable after the call above.
    from lightglue.utils import numpy_image_to_torch

    if bgr_uint8_numpy_array.ndim == 3:
        # Last axis is the channel axis: reverse it to go from BGR to RGB.
        image_numpy_array = bgr_uint8_numpy_array[..., ::-1]
    else:
        # No channel axis to reorder (e.g. grayscale); do not touch the pixels.
        image_numpy_array = bgr_uint8_numpy_array
    return numpy_image_to_torch(image_numpy_array)
def _load_images_to_device(
    template_image_source,
    large_image_source,
    device: torch.device,
):
    """Load the template and large images as tensors and move both to ``device``.

    Each source may be either a numpy array (handled by the BGR conversion
    helper) or anything accepted by ``Path(...)``, which is read from disk with
    LightGlue's ``read_image`` and converted with ``numpy_image_to_torch``.

    Returns:
        ``(image0, image1)``: template tensor and large-image tensor on ``device``.
    """
    _ensure_lightglue_python_package_is_importable()
    # Imported lazily: the package only becomes importable after the call above.
    from lightglue.utils import numpy_image_to_torch, read_image

    def _image_source_to_tensor(image_source):
        # In-memory arrays go through the BGR helper; everything else is a path.
        if isinstance(image_source, np.ndarray):
            return _numpy_bgr_uint8_array_to_rgb_torch_chw_float_zero_to_one(
                image_source
            )
        return numpy_image_to_torch(read_image(Path(image_source)))

    image0 = _image_source_to_tensor(template_image_source)
    image1 = _image_source_to_tensor(large_image_source)
    return image0.to(device), image1.to(device)
def _template_xy_large_xy_scores_from_descriptor_nearest_neighbor(
    feats0: dict,
    feats1: dict,
):
    """Match every template keypoint to its most similar large-image keypoint.

    For each template-side descriptor, the large-side descriptor with the
    highest cosine similarity is selected. No threshold is applied, so every
    template keypoint gets exactly one match whenever the large side has at
    least one keypoint.

    Args:
        feats0: template-side features; reads ``"keypoints"`` and
            ``"descriptors"`` tensors.
        feats1: large-image features with the same two keys.
            Descriptors are treated as 2-D matrices (one row per keypoint);
            keypoints are assumed to be ``(N, 2)`` xy rows -- TODO confirm.

    Returns:
        A tuple of three float64 numpy arrays:
        ``(template_xy, matched_large_xy, cosine_similarity_scores)``.
        Row ``i`` of each array describes the same match. When either side has
        no keypoints, three empty arrays of shape ``(0, 2)``, ``(0, 2)`` and
        ``(0,)`` are returned.
    """

    def _to_float64_numpy(tensor: torch.Tensor) -> np.ndarray:
        # Detach and move to host memory before handing the data to numpy.
        return tensor.detach().cpu().numpy().astype(np.float64)

    template_keypoints_xy = feats0["keypoints"].float()
    large_keypoints_xy = feats1["keypoints"].float()

    # L2-normalise rows so the matrix product below is the cosine similarity.
    template_descriptors_unit = torch.nn.functional.normalize(
        feats0["descriptors"].float(), p=2, dim=-1
    )
    large_descriptors_unit = torch.nn.functional.normalize(
        feats1["descriptors"].float(), p=2, dim=-1
    )
    cosine_similarity_matrix = template_descriptors_unit @ large_descriptors_unit.T

    # No template points, or no candidates to match against: ``argmax`` over an
    # empty dimension would raise, so report "no matches" instead. Separate
    # arrays are returned so callers can mutate one without touching the other.
    if cosine_similarity_matrix.numel() == 0:
        return (
            np.zeros((0, 2), dtype=np.float64),
            np.zeros((0, 2), dtype=np.float64),
            np.zeros((0,), dtype=np.float64),
        )

    best_large_index_per_template = cosine_similarity_matrix.argmax(dim=1)
    best_score_per_template = cosine_similarity_matrix.gather(
        1, best_large_index_per_template.unsqueeze(1)
    ).squeeze(1)

    return (
        _to_float64_numpy(template_keypoints_xy),
        _to_float64_numpy(large_keypoints_xy[best_large_index_per_template]),
        _to_float64_numpy(best_score_per_template),
    )
|