lightglue.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. from typing import ClassVar
  19. import torch
  20. from kornia.core import Device, Tensor
  21. from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_DEVICES, KORNIA_CHECK_SHAPE
  22. from .utils import download_onnx_from_url, normalize_keypoints
  23. try:
  24. import numpy as np
  25. import onnxruntime as ort
  26. except ImportError:
  27. np = None # type: ignore
  28. ort = None
  29. __all__ = ["OnnxLightGlue"]
  30. class OnnxLightGlue:
  31. r"""Wrapper for loading LightGlue-ONNX models and running inference via ONNXRuntime.
  32. LightGlue :cite:`LightGlue2023` performs fast descriptor-based deep keypoint matching.
  33. This module requires `onnxruntime` to be installed.
  34. If you have trained your own LightGlue model, see https://github.com/fabio-sim/LightGlue-ONNX
  35. for how to export the model to ONNX and optimize it.
  36. Args:
  37. weights: Pretrained weights, or a path to your own exported ONNX model. Available pretrained weights
  38. are ``'disk'``, ``'superpoint'``, ``'disk_fp16'``, and ``'superpoint_fp16'``. `Note that FP16 requires CUDA.`
  39. Defaults to ``'disk_fp16'`` if ``device`` is CUDA, and ``'disk'`` if CPU.
  40. device: Device to run inference on.
  41. """
  42. MODEL_URLS: ClassVar[dict[str, str]] = {
  43. "disk": "https://github.com/fabio-sim/LightGlue-ONNX/releases/download/v1.0.0/disk_lightglue_fused.onnx",
  44. "superpoint": "https://github.com/fabio-sim/LightGlue-ONNX/releases/download/v1.0.0/superpoint_lightglue_fused.onnx",
  45. "disk_fp16": "https://github.com/fabio-sim/LightGlue-ONNX/releases/download/v1.0.0/disk_lightglue_fused_fp16.onnx",
  46. "superpoint_fp16": "https://github.com/fabio-sim/LightGlue-ONNX/releases/download/v1.0.0/superpoint_lightglue_fused_fp16.onnx",
  47. }
  48. required_data_keys: ClassVar[list[str]] = ["image0", "image1"]
  49. def __init__(self, weights: str | None = None, device: Device = "cpu") -> None:
  50. KORNIA_CHECK(ort is not None, "onnxruntime is not installed.")
  51. KORNIA_CHECK(np is not None, "numpy is not installed.")
  52. device = torch.device(device) # type: ignore
  53. self.device = device
  54. if device.type == "cpu":
  55. providers = ["CPUExecutionProvider"]
  56. elif device.type == "cuda":
  57. providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
  58. else:
  59. raise ValueError(f"Unsupported device {device}")
  60. if weights is None:
  61. weights = "disk_fp16" if device.type == "cuda" else "disk"
  62. if weights in self.MODEL_URLS:
  63. if "fp16" in weights:
  64. KORNIA_CHECK(device.type == "cuda", "FP16 requires CUDA.")
  65. url = self.MODEL_URLS[weights]
  66. if device.type == "cpu":
  67. url = url.replace(".onnx", "_cpu.onnx")
  68. weights = download_onnx_from_url(url)
  69. self.session = ort.InferenceSession(weights, providers=providers)
  70. def __call__(self, data: dict[str, dict[str, Tensor]]) -> dict[str, Tensor]:
  71. return self.forward(data)
  72. def forward(self, data: dict[str, dict[str, Tensor]]) -> dict[str, Tensor]:
  73. r"""Match keypoints and descriptors between two images.
  74. The output contains the matches (the indices of the matching keypoint pairs between the first and second image)
  75. and the corresponding confidence scores.
  76. Only a batch size of 1 is supported.
  77. Args:
  78. data: Dictionary containing both images and the keypoints and descriptors thereof.
  79. Returns:
  80. Dictionary containing the matches and scores.
  81. ``data`` (``dict``):
  82. ``image0`` (``dict``):
  83. ``keypoints`` (`float32`): :math:`(1, M, 2)`
  84. ``descriptors`` (`float32`): :math:`(1, M, D)`
  85. ``image``: :math:`(1, C, H, W)` or ``image_size``: :math:`(1, 2)`
  86. ``image1`` (``dict``):
  87. ``keypoints`` (`float32`): :math:`(1, N, 2)`
  88. ``descriptors`` (`float32`): :math:`(1, N, D)`
  89. ``image``: :math:`(1, C, H, W)` or ``image_size``: :math:`(1, 2)`
  90. ``output`` (``dict``):
  91. ``matches`` (`int64`): :math:`(S, 2)`
  92. ``scores`` (`float32`): :math:`(S)`
  93. """
  94. # Input validation.
  95. for key in self.required_data_keys:
  96. KORNIA_CHECK(key in data, f"Missing key {key} in data")
  97. data0, data1 = data["image0"], data["image1"]
  98. kpts0_, kpts1_ = data0["keypoints"].contiguous(), data1["keypoints"].contiguous()
  99. desc0, desc1 = data0["descriptors"].contiguous(), data1["descriptors"].contiguous()
  100. KORNIA_CHECK_SAME_DEVICES([kpts0_, desc0, kpts1_, desc1], "Wrong device")
  101. KORNIA_CHECK(kpts0_.device.type == self.device.type, "Wrong device")
  102. KORNIA_CHECK(torch.float32 == kpts0_.dtype == kpts1_.dtype == desc0.dtype == desc1.dtype, "Wrong dtype")
  103. KORNIA_CHECK_SHAPE(kpts0_, ["1", "M", "2"])
  104. KORNIA_CHECK_SHAPE(kpts1_, ["1", "N", "2"])
  105. KORNIA_CHECK_SHAPE(desc0, ["1", "M", "D"])
  106. KORNIA_CHECK_SHAPE(desc1, ["1", "N", "D"])
  107. KORNIA_CHECK(kpts0_.shape[1] == desc0.shape[1], "Number of keypoints does not match number of descriptors")
  108. KORNIA_CHECK(kpts1_.shape[1] == desc1.shape[1], "Number of keypoints does not match number of descriptors")
  109. KORNIA_CHECK(desc0.shape[2] == desc1.shape[2], "Descriptors' dimensions do not match")
  110. # Normalize keypoints.
  111. size0, size1 = data0.get("image_size"), data1.get("image_size")
  112. size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1] # type: ignore
  113. size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1] # type: ignore
  114. kpts0 = normalize_keypoints(kpts0_, size=size0) # type: ignore
  115. kpts1 = normalize_keypoints(kpts1_, size=size1) # type: ignore
  116. KORNIA_CHECK(torch.all(kpts0 >= -1).item() and torch.all(kpts0 <= 1).item(), "") # type: ignore
  117. KORNIA_CHECK(torch.all(kpts1 >= -1).item() and torch.all(kpts1 <= 1).item(), "") # type: ignore
  118. # Inference.
  119. lightglue_inputs = {"kpts0": kpts0, "kpts1": kpts1, "desc0": desc0, "desc1": desc1}
  120. lightglue_outputs = ["matches0", "mscores0"]
  121. binding = self.session.io_binding()
  122. for name, tensor in lightglue_inputs.items():
  123. binding.bind_input(
  124. name,
  125. device_type=self.device.type,
  126. device_id=0,
  127. element_type=np.float32,
  128. shape=tuple(tensor.shape),
  129. buffer_ptr=tensor.data_ptr(),
  130. )
  131. for name in lightglue_outputs:
  132. binding.bind_output(name, device_type=self.device.type, device_id=0)
  133. self.session.run_with_iobinding(binding)
  134. matches, mscores = binding.get_outputs()
  135. # TODO: The following is an unnecessary copy. Replace with a better solution when torch supports
  136. # constructing a tensor from a data pointer, or when ORT supports converting to torch tensor.
  137. # https://github.com/microsoft/onnxruntime/issues/15963
  138. outputs = {
  139. "matches": torch.from_dlpack(matches.numpy()).to(self.device),
  140. "scores": torch.from_dlpack(mscores.numpy()).to(self.device),
  141. }
  142. return outputs