modular_efficientloftr.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from typing import TYPE_CHECKING
  2. from ...processing_utils import ImagesKwargs
  3. from ...utils import TensorType, is_torch_available
  4. from ...utils.import_utils import requires
  5. from ..superglue.image_processing_pil_superglue import SuperGlueImageProcessorPil
  6. from ..superglue.image_processing_superglue import SuperGlueImageProcessor
  7. if is_torch_available():
  8. import torch
  9. if TYPE_CHECKING:
  10. from .modeling_efficientloftr import EfficientLoFTRKeypointMatchingOutput
# Extra keyword arguments for the EfficientLoFTR image processors, layered on top of the generic `ImagesKwargs`
# (presumably consumed by `preprocess` via the SuperGlue base processors — confirm against the base classes).
class EfficientLoFTRImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
        Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
    """

    # Declared with `total=False`, so this key may be omitted entirely by callers.
    do_grayscale: bool
  17. class EfficientLoFTRImageProcessor(SuperGlueImageProcessor):
  18. def post_process_keypoint_matching(
  19. self,
  20. outputs: "EfficientLoFTRKeypointMatchingOutput",
  21. target_sizes: TensorType | list[tuple],
  22. threshold: float = 0.0,
  23. ) -> list[dict[str, torch.Tensor]]:
  24. """
  25. Converts the raw output of [`EfficientLoFTRKeypointMatchingOutput`] into lists of keypoints, scores and descriptors
  26. with coordinates absolute to the original image sizes.
  27. Args:
  28. outputs ([`EfficientLoFTRKeypointMatchingOutput`]):
  29. Raw outputs of the model.
  30. target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
  31. Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
  32. target size `(height, width)` of each image in the batch. This must be the original image size (before
  33. any processing).
  34. threshold (`float`, *optional*, defaults to 0.0):
  35. Threshold to filter out the matches with low scores.
  36. Returns:
  37. `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
  38. of the pair, the matching scores and the matching indices.
  39. """
  40. if outputs.matches.shape[0] != len(target_sizes):
  41. raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
  42. if not all(len(target_size) == 2 for target_size in target_sizes):
  43. raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
  44. if isinstance(target_sizes, list):
  45. image_pair_sizes = torch.tensor(target_sizes, device=outputs.matches.device)
  46. else:
  47. if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
  48. raise ValueError(
  49. "Each element of target_sizes must contain the size (h, w) of each image of the batch"
  50. )
  51. image_pair_sizes = target_sizes
  52. keypoints = outputs.keypoints.clone()
  53. keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
  54. keypoints = keypoints.to(torch.int32)
  55. results = []
  56. for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
  57. # Filter out matches with low scores
  58. valid_matches = torch.logical_and(scores > threshold, matches > -1)
  59. matched_keypoints0 = keypoints_pair[0][valid_matches[0]]
  60. matched_keypoints1 = keypoints_pair[1][valid_matches[1]]
  61. matching_scores = scores[0][valid_matches[0]]
  62. results.append(
  63. {
  64. "keypoints0": matched_keypoints0,
  65. "keypoints1": matched_keypoints1,
  66. "matching_scores": matching_scores,
  67. }
  68. )
  69. return results
  70. class EfficientLoFTRImageProcessorPil(SuperGlueImageProcessorPil):
  71. @requires(backends=("torch",))
  72. def post_process_keypoint_matching(
  73. self,
  74. outputs: "EfficientLoFTRKeypointMatchingOutput",
  75. target_sizes: TensorType | list[tuple],
  76. threshold: float = 0.0,
  77. ) -> list[dict[str, "torch.Tensor"]]:
  78. """
  79. Converts the raw output of [`EfficientLoFTRKeypointMatchingOutput`] into lists of keypoints, scores and descriptors
  80. with coordinates absolute to the original image sizes.
  81. Args:
  82. outputs ([`EfficientLoFTRKeypointMatchingOutput`]):
  83. Raw outputs of the model.
  84. target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
  85. Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
  86. target size `(height, width)` of each image in the batch. This must be the original image size (before
  87. any processing).
  88. threshold (`float`, *optional*, defaults to 0.0):
  89. Threshold to filter out the matches with low scores.
  90. Returns:
  91. `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
  92. of the pair, the matching scores and the matching indices.
  93. """
  94. import torch
  95. if outputs.matches.shape[0] != len(target_sizes):
  96. raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
  97. if not all(len(target_size) == 2 for target_size in target_sizes):
  98. raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
  99. if isinstance(target_sizes, list):
  100. image_pair_sizes = torch.tensor(target_sizes, device=outputs.matches.device)
  101. else:
  102. if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
  103. raise ValueError(
  104. "Each element of target_sizes must contain the size (h, w) of each image of the batch"
  105. )
  106. image_pair_sizes = target_sizes
  107. keypoints = outputs.keypoints.clone()
  108. keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
  109. keypoints = keypoints.to(torch.int32)
  110. results = []
  111. for keypoints_pair, matches, scores in zip(keypoints, outputs.matches, outputs.matching_scores):
  112. # Filter out matches with low scores
  113. valid_matches = torch.logical_and(scores > threshold, matches > -1)
  114. matched_keypoints0 = keypoints_pair[0][valid_matches[0]]
  115. matched_keypoints1 = keypoints_pair[1][valid_matches[1]]
  116. matching_scores = scores[0][valid_matches[0]]
  117. results.append(
  118. {
  119. "keypoints0": matched_keypoints0,
  120. "keypoints1": matched_keypoints1,
  121. "matching_scores": matching_scores,
  122. }
  123. )
  124. return results
# Public names exported by this module (what `from ... import *` exposes).
__all__ = ["EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorPil"]