lightglue-img-match.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from __future__ import annotations
  2. import sys
  3. from pathlib import Path
  4. import numpy as np
  5. import torch
  6. _REPOSITORY_ROOT_DIRECTORY_PATH = Path(__file__).resolve().parent.parent
  7. _LIGHTGLUE_REPOSITORY_PACKAGE_PARENT_DIRECTORY_PATH = (
  8. _REPOSITORY_ROOT_DIRECTORY_PATH / "python" / "LightGlue"
  9. )
  10. _DEFAULT_SUPERPOINT_MAX_NUM_KEYPOINTS_INTEGER = 2048
  11. def _ensure_lightglue_python_package_is_importable() -> None:
  12. path_string = str(_LIGHTGLUE_REPOSITORY_PACKAGE_PARENT_DIRECTORY_PATH)
  13. if path_string not in sys.path:
  14. sys.path.insert(0, path_string)
  15. def _resolve_existing_image_paths(
  16. template_image_path_or_bgr_numpy_array,
  17. large_image_path_or_bgr_numpy_array,
  18. ):
  19. return template_image_path_or_bgr_numpy_array, large_image_path_or_bgr_numpy_array
  20. def _choose_torch_device(device: torch.device | None) -> torch.device:
  21. if device is not None:
  22. return device
  23. return torch.device("cuda" if torch.cuda.is_available() else "cpu")
  24. def _effective_superpoint_keypoint_cap(
  25. max_num_keypoints: int | None,
  26. ) -> int:
  27. if max_num_keypoints is None:
  28. return _DEFAULT_SUPERPOINT_MAX_NUM_KEYPOINTS_INTEGER
  29. return int(max_num_keypoints)
  30. def _get_superpoint_extractor(device: torch.device, max_num_keypoints: int):
  31. _ensure_lightglue_python_package_is_importable()
  32. from lightglue import SuperPoint
  33. extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(device)
  34. return extractor
  35. def extract_superpoint_features_pair_rbd_on_device(
  36. extractor,
  37. image0: torch.Tensor,
  38. image1: torch.Tensor,
  39. device: torch.device,
  40. ):
  41. _ensure_lightglue_python_package_is_importable()
  42. from lightglue.utils import batch_to_device, rbd
  43. feats0 = extractor.extract(image0)
  44. feats1 = extractor.extract(image1)
  45. device_string = str(device)
  46. feats0 = batch_to_device(rbd(feats0), device_string)
  47. feats1 = batch_to_device(rbd(feats1), device_string)
  48. return feats0, feats1
  49. def _numpy_bgr_uint8_array_to_rgb_torch_chw_float_zero_to_one(
  50. bgr_uint8_numpy_array: np.ndarray,
  51. ) -> torch.Tensor:
  52. _ensure_lightglue_python_package_is_importable()
  53. from lightglue.utils import numpy_image_to_torch
  54. rgb_uint8_numpy_array = bgr_uint8_numpy_array[..., ::-1]
  55. return numpy_image_to_torch(rgb_uint8_numpy_array)
  56. def _load_images_to_device(
  57. template_image_source,
  58. large_image_source,
  59. device: torch.device,
  60. ):
  61. _ensure_lightglue_python_package_is_importable()
  62. from lightglue.utils import numpy_image_to_torch, read_image
  63. if isinstance(template_image_source, np.ndarray):
  64. image0 = _numpy_bgr_uint8_array_to_rgb_torch_chw_float_zero_to_one(
  65. template_image_source
  66. )
  67. else:
  68. image0 = numpy_image_to_torch(read_image(Path(template_image_source)))
  69. if isinstance(large_image_source, np.ndarray):
  70. image1 = _numpy_bgr_uint8_array_to_rgb_torch_chw_float_zero_to_one(
  71. large_image_source
  72. )
  73. else:
  74. image1 = numpy_image_to_torch(read_image(Path(large_image_source)))
  75. return image0.to(device), image1.to(device)
  76. def _template_xy_large_xy_scores_from_descriptor_nearest_neighbor(
  77. feats0: dict,
  78. feats1: dict,
  79. ):
  80. """模板侧每个 SuperPoint 点在截图侧取描述子余弦相似度最大的一个点;无阈值剔除。"""
  81. keypoints_template_image_xy_torch = feats0["keypoints"].float()
  82. keypoints_large_image_xy_torch = feats1["keypoints"].float()
  83. descriptors_template_torch = feats0["descriptors"].float()
  84. descriptors_large_torch = feats1["descriptors"].float()
  85. descriptors_template_normalized_torch = torch.nn.functional.normalize(
  86. descriptors_template_torch,
  87. p=2,
  88. dim=-1,
  89. )
  90. descriptors_large_normalized_torch = torch.nn.functional.normalize(
  91. descriptors_large_torch,
  92. p=2,
  93. dim=-1,
  94. )
  95. cosine_similarity_matrix = (
  96. descriptors_template_normalized_torch @ descriptors_large_normalized_torch.T
  97. )
  98. template_point_count_integer = int(cosine_similarity_matrix.shape[0])
  99. if template_point_count_integer == 0:
  100. empty_xy_numpy_array = np.zeros((0, 2), dtype=np.float64)
  101. empty_scores_numpy_array = np.zeros((0,), dtype=np.float64)
  102. return empty_xy_numpy_array, empty_xy_numpy_array, empty_scores_numpy_array
  103. best_large_point_index_per_template_row = cosine_similarity_matrix.argmax(dim=1)
  104. row_index_torch = torch.arange(
  105. template_point_count_integer,
  106. device=cosine_similarity_matrix.device,
  107. dtype=torch.long,
  108. )
  109. confidence_score_per_template_point_torch = cosine_similarity_matrix[
  110. row_index_torch,
  111. best_large_point_index_per_template_row,
  112. ]
  113. template_xy_numpy_array = (
  114. keypoints_template_image_xy_torch.detach().cpu().numpy().astype(np.float64)
  115. )
  116. large_xy_matched_numpy_array = (
  117. keypoints_large_image_xy_torch[
  118. best_large_point_index_per_template_row
  119. ]
  120. .detach()
  121. .cpu()
  122. .numpy()
  123. .astype(np.float64)
  124. )
  125. confidence_scores_numpy_array = (
  126. confidence_score_per_template_point_torch.detach().cpu().numpy().astype(
  127. np.float64
  128. )
  129. )
  130. return (
  131. template_xy_numpy_array,
  132. large_xy_matched_numpy_array,
  133. confidence_scores_numpy_array,
  134. )