# benchmark.py
# Benchmark script for LightGlue on real images
import argparse
import time
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image

# Inference-only benchmark: disable autograd globally.
torch.set_grad_enabled(False)
  13. def measure(matcher, data, device="cuda", r=100):
  14. timings = np.zeros((r, 1))
  15. if device.type == "cuda":
  16. starter = torch.cuda.Event(enable_timing=True)
  17. ender = torch.cuda.Event(enable_timing=True)
  18. # warmup
  19. for _ in range(10):
  20. _ = matcher(data)
  21. # measurements
  22. with torch.no_grad():
  23. for rep in range(r):
  24. if device.type == "cuda":
  25. starter.record()
  26. _ = matcher(data)
  27. ender.record()
  28. # sync gpu
  29. torch.cuda.synchronize()
  30. curr_time = starter.elapsed_time(ender)
  31. else:
  32. start = time.perf_counter()
  33. _ = matcher(data)
  34. curr_time = (time.perf_counter() - start) * 1e3
  35. timings[rep] = curr_time
  36. mean_syn = np.sum(timings) / r
  37. std_syn = np.std(timings)
  38. return {"mean": mean_syn, "std": std_syn}
  39. def print_as_table(d, title, cnames):
  40. print()
  41. header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
  42. print(header)
  43. print("-" * len(header))
  44. for k, l in d.items():
  45. print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
  46. if __name__ == "__main__":
  47. parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
  48. parser.add_argument(
  49. "--device",
  50. choices=["auto", "cuda", "cpu", "mps"],
  51. default="auto",
  52. help="device to benchmark on",
  53. )
  54. parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
  55. parser.add_argument(
  56. "--no_flash", action="store_true", help="disable FlashAttention"
  57. )
  58. parser.add_argument(
  59. "--no_prune_thresholds",
  60. action="store_true",
  61. help="disable pruning thresholds (i.e. always do pruning)",
  62. )
  63. parser.add_argument(
  64. "--add_superglue",
  65. action="store_true",
  66. help="add SuperGlue to the benchmark (requires hloc)",
  67. )
  68. parser.add_argument(
  69. "--measure", default="time", choices=["time", "log-time", "throughput"]
  70. )
  71. parser.add_argument(
  72. "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
  73. )
  74. parser.add_argument(
  75. "--num_keypoints",
  76. nargs="+",
  77. type=int,
  78. default=[256, 512, 1024, 2048, 4096],
  79. help="number of keypoints (list separated by spaces)",
  80. )
  81. parser.add_argument(
  82. "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
  83. )
  84. parser.add_argument(
  85. "--save", default=None, type=str, help="path where figure should be saved"
  86. )
  87. args = parser.parse_intermixed_args()
  88. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  89. if args.device != "auto":
  90. device = torch.device(args.device)
  91. print("Running benchmark on device:", device)
  92. images = Path("assets")
  93. inputs = {
  94. "easy": (
  95. load_image(images / "DSC_0411.JPG"),
  96. load_image(images / "DSC_0410.JPG"),
  97. ),
  98. "difficult": (
  99. load_image(images / "sacre_coeur1.jpg"),
  100. load_image(images / "sacre_coeur2.jpg"),
  101. ),
  102. }
  103. configs = {
  104. "LightGlue-full": {
  105. "depth_confidence": -1,
  106. "width_confidence": -1,
  107. },
  108. # 'LG-prune': {
  109. # 'width_confidence': -1,
  110. # },
  111. # 'LG-depth': {
  112. # 'depth_confidence': -1,
  113. # },
  114. "LightGlue-adaptive": {},
  115. }
  116. if args.compile:
  117. configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
  118. sg_configs = {
  119. # 'SuperGlue': {},
  120. "SuperGlue-fast": {"sinkhorn_iterations": 5}
  121. }
  122. torch.set_float32_matmul_precision(args.matmul_precision)
  123. results = {k: defaultdict(list) for k, v in inputs.items()}
  124. extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
  125. extractor = extractor.eval().to(device)
  126. figsize = (len(inputs) * 4.5, 4.5)
  127. fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
  128. axes = axes if len(inputs) > 1 else [axes]
  129. fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
  130. for title, ax in zip(inputs.keys(), axes):
  131. ax.set_xscale("log", base=2)
  132. bases = [2**x for x in range(7, 16)]
  133. ax.set_xticks(bases, bases)
  134. ax.grid(which="major")
  135. if args.measure == "log-time":
  136. ax.set_yscale("log")
  137. yticks = [10**x for x in range(6)]
  138. ax.set_yticks(yticks, yticks)
  139. mpos = [10**x * i for x in range(6) for i in range(2, 10)]
  140. mlabel = [
  141. 10**x * i if i in [2, 5] else None
  142. for x in range(6)
  143. for i in range(2, 10)
  144. ]
  145. ax.set_yticks(mpos, mlabel, minor=True)
  146. ax.grid(which="minor", linewidth=0.2)
  147. ax.set_title(title)
  148. ax.set_xlabel("# keypoints")
  149. if args.measure == "throughput":
  150. ax.set_ylabel("Throughput [pairs/s]")
  151. else:
  152. ax.set_ylabel("Latency [ms]")
  153. for name, conf in configs.items():
  154. print("Run benchmark for:", name)
  155. torch.cuda.empty_cache()
  156. matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
  157. if args.no_prune_thresholds:
  158. matcher.pruning_keypoint_thresholds = {
  159. k: -1 for k in matcher.pruning_keypoint_thresholds
  160. }
  161. matcher = matcher.eval().to(device)
  162. if name.endswith("compile"):
  163. import torch._dynamo
  164. torch._dynamo.reset() # avoid buffer overflow
  165. matcher.compile()
  166. for pair_name, ax in zip(inputs.keys(), axes):
  167. image0, image1 = [x.to(device) for x in inputs[pair_name]]
  168. runtimes = []
  169. for num_kpts in args.num_keypoints:
  170. extractor.conf.max_num_keypoints = num_kpts
  171. feats0 = extractor.extract(image0)
  172. feats1 = extractor.extract(image1)
  173. runtime = measure(
  174. matcher,
  175. {"image0": feats0, "image1": feats1},
  176. device=device,
  177. r=args.repeat,
  178. )["mean"]
  179. results[pair_name][name].append(
  180. 1000 / runtime if args.measure == "throughput" else runtime
  181. )
  182. ax.plot(
  183. args.num_keypoints, results[pair_name][name], label=name, marker="o"
  184. )
  185. del matcher, feats0, feats1
  186. if args.add_superglue:
  187. from hloc.matchers.superglue import SuperGlue
  188. for name, conf in sg_configs.items():
  189. print("Run benchmark for:", name)
  190. matcher = SuperGlue(conf)
  191. matcher = matcher.eval().to(device)
  192. for pair_name, ax in zip(inputs.keys(), axes):
  193. image0, image1 = [x.to(device) for x in inputs[pair_name]]
  194. runtimes = []
  195. for num_kpts in args.num_keypoints:
  196. extractor.conf.max_num_keypoints = num_kpts
  197. feats0 = extractor.extract(image0)
  198. feats1 = extractor.extract(image1)
  199. data = {
  200. "image0": image0[None],
  201. "image1": image1[None],
  202. **{k + "0": v for k, v in feats0.items()},
  203. **{k + "1": v for k, v in feats1.items()},
  204. }
  205. data["scores0"] = data["keypoint_scores0"]
  206. data["scores1"] = data["keypoint_scores1"]
  207. data["descriptors0"] = (
  208. data["descriptors0"].transpose(-1, -2).contiguous()
  209. )
  210. data["descriptors1"] = (
  211. data["descriptors1"].transpose(-1, -2).contiguous()
  212. )
  213. runtime = measure(matcher, data, device=device, r=args.repeat)[
  214. "mean"
  215. ]
  216. results[pair_name][name].append(
  217. 1000 / runtime if args.measure == "throughput" else runtime
  218. )
  219. ax.plot(
  220. args.num_keypoints, results[pair_name][name], label=name, marker="o"
  221. )
  222. del matcher, data, image0, image1, feats0, feats1
  223. for name, runtimes in results.items():
  224. print_as_table(runtimes, name, args.num_keypoints)
  225. axes[0].legend()
  226. fig.tight_layout()
  227. if args.save:
  228. plt.savefig(args.save, dpi=fig.dpi)
  229. plt.show()