# Benchmark script for LightGlue on real images
import argparse
import time
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image

# The benchmark only runs inference; disable autograd globally.
torch.set_grad_enabled(False)
def measure(matcher, data, device="cuda", r=100):
    """Time ``r`` forward passes of ``matcher`` on ``data``.

    Args:
        matcher: callable taking a single dict argument (the match inputs).
        data: input dict forwarded unchanged to ``matcher``.
        device: ``torch.device`` or device string. CUDA devices are timed
            with CUDA events (GPU-accurate); everything else falls back to
            a wall clock.
        r: number of timed repetitions.

    Returns:
        dict with keys ``"mean"`` and ``"std"``, both in milliseconds.
    """
    # Accept plain strings ("cuda", "cpu") as well as torch.device objects;
    # the previous code crashed on `device.type` when given the default
    # string argument.
    device = torch.device(device) if isinstance(device, str) else device
    timings = np.zeros((r, 1))
    if device.type == "cuda":
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
    # Warmup: let lazy initialization / autotuning settle before measuring.
    for _ in range(10):
        _ = matcher(data)
    # measurements
    with torch.no_grad():
        for rep in range(r):
            if device.type == "cuda":
                starter.record()
                _ = matcher(data)
                ender.record()
                # Wait for the GPU to finish before reading the event timer.
                torch.cuda.synchronize()
                curr_time = starter.elapsed_time(ender)
            else:
                start = time.perf_counter()
                _ = matcher(data)
                curr_time = (time.perf_counter() - start) * 1e3  # s -> ms
            timings[rep] = curr_time
    mean_syn = np.mean(timings)
    std_syn = np.std(timings)
    return {"mean": mean_syn, "std": std_syn}
def print_as_table(d, title, cnames):
    """Pretty-print ``d`` (row name -> list of floats) as an aligned table.

    ``title`` labels the first column and ``cnames`` provide the remaining
    column headers; values are rendered with one decimal place.
    """
    print()
    header = f"{title:30} " + " ".join(f"{c:>7}" for c in cnames)
    print(header)
    print("-" * len(header))
    for row_name, values in d.items():
        cells = " ".join(f"{v:>7.1f}" for v in values)
        print(f"{row_name:30}", cells)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "cpu", "mps"],
        default="auto",
        help="device to benchmark on",
    )
    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
    parser.add_argument(
        "--no_flash", action="store_true", help="disable FlashAttention"
    )
    parser.add_argument(
        "--no_prune_thresholds",
        action="store_true",
        help="disable pruning thresholds (i.e. always do pruning)",
    )
    parser.add_argument(
        "--add_superglue",
        action="store_true",
        help="add SuperGlue to the benchmark (requires hloc)",
    )
    parser.add_argument(
        "--measure", default="time", choices=["time", "log-time", "throughput"]
    )
    parser.add_argument(
        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
    )
    parser.add_argument(
        "--num_keypoints",
        nargs="+",
        type=int,
        default=[256, 512, 1024, 2048, 4096],
        help="number of keypoints (list separated by spaces)",
    )
    parser.add_argument(
        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
    )
    parser.add_argument(
        "--save", default=None, type=str, help="path where figure should be saved"
    )
    args = parser.parse_intermixed_args()

    # Resolve the benchmark device; "auto" prefers CUDA when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.device != "auto":
        device = torch.device(args.device)
    print("Running benchmark on device:", device)

    # Two image pairs of increasing matching difficulty, shipped in assets/.
    images = Path("assets")
    inputs = {
        "easy": (
            load_image(images / "DSC_0411.JPG"),
            load_image(images / "DSC_0410.JPG"),
        ),
        "difficult": (
            load_image(images / "sacre_coeur1.jpg"),
            load_image(images / "sacre_coeur2.jpg"),
        ),
    }

    # LightGlue variants: "full" disables adaptive depth/width pruning,
    # "adaptive" uses the default early-exit configuration.
    configs = {
        "LightGlue-full": {
            "depth_confidence": -1,
            "width_confidence": -1,
        },
        # 'LG-prune': {
        #     'width_confidence': -1,
        # },
        # 'LG-depth': {
        #     'depth_confidence': -1,
        # },
        "LightGlue-adaptive": {},
    }
    if args.compile:
        # Benchmark every configuration a second time with torch.compile.
        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}

    sg_configs = {
        # 'SuperGlue': {},
        "SuperGlue-fast": {"sinkhorn_iterations": 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    # results[pair_name][config_name] -> one measurement per keypoint budget.
    results = {k: defaultdict(list) for k in inputs}

    # Shared keypoint extractor; max_num_keypoints is overridden per run.
    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)

    # One subplot per image pair, sharing the y-axis for comparability.
    figsize = (len(inputs) * 4.5, 4.5)
    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
    axes = axes if len(inputs) > 1 else [axes]
    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")

    for title, ax in zip(inputs.keys(), axes):
        ax.set_xscale("log", base=2)
        bases = [2**x for x in range(7, 16)]
        ax.set_xticks(bases, bases)
        ax.grid(which="major")
        if args.measure == "log-time":
            ax.set_yscale("log")
            yticks = [10**x for x in range(6)]
            ax.set_yticks(yticks, yticks)
            # Minor ticks at every decade multiple; label only 2x and 5x.
            mpos = [10**x * i for x in range(6) for i in range(2, 10)]
            mlabel = [
                10**x * i if i in [2, 5] else None
                for x in range(6)
                for i in range(2, 10)
            ]
            ax.set_yticks(mpos, mlabel, minor=True)
            ax.grid(which="minor", linewidth=0.2)
        ax.set_title(title)
        ax.set_xlabel("# keypoints")
        if args.measure == "throughput":
            ax.set_ylabel("Throughput [pairs/s]")
        else:
            ax.set_ylabel("Latency [ms]")

    for name, conf in configs.items():
        print("Run benchmark for:", name)
        if device.type == "cuda":
            # Release cached allocations left over from the previous matcher.
            torch.cuda.empty_cache()
        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            # Force pruning at every keypoint count, regardless of thresholds.
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds
            }
        matcher = matcher.eval().to(device)
        if name.endswith("compile"):
            # torch._dynamo is imported at module level.
            torch._dynamo.reset()  # avoid buffer overflow
            matcher.compile()
        for pair_name, ax in zip(inputs.keys(), axes):
            image0, image1 = [x.to(device) for x in inputs[pair_name]]
            for num_kpts in args.num_keypoints:
                extractor.conf.max_num_keypoints = num_kpts
                feats0 = extractor.extract(image0)
                feats1 = extractor.extract(image1)
                runtime = measure(
                    matcher,
                    {"image0": feats0, "image1": feats1},
                    device=device,
                    r=args.repeat,
                )["mean"]
                results[pair_name][name].append(
                    1000 / runtime if args.measure == "throughput" else runtime
                )
            ax.plot(
                args.num_keypoints, results[pair_name][name], label=name, marker="o"
            )
        del matcher, feats0, feats1

    if args.add_superglue:
        # Optional baseline; hloc is only required when --add_superglue is set.
        from hloc.matchers.superglue import SuperGlue

        for name, conf in sg_configs.items():
            print("Run benchmark for:", name)
            matcher = SuperGlue(conf)
            matcher = matcher.eval().to(device)
            for pair_name, ax in zip(inputs.keys(), axes):
                image0, image1 = [x.to(device) for x in inputs[pair_name]]
                for num_kpts in args.num_keypoints:
                    extractor.conf.max_num_keypoints = num_kpts
                    feats0 = extractor.extract(image0)
                    feats1 = extractor.extract(image1)
                    # Re-pack the batched LightGlue-style features into the
                    # flat key layout expected by hloc's SuperGlue wrapper.
                    data = {
                        "image0": image0[None],
                        "image1": image1[None],
                        **{k + "0": v for k, v in feats0.items()},
                        **{k + "1": v for k, v in feats1.items()},
                    }
                    data["scores0"] = data["keypoint_scores0"]
                    data["scores1"] = data["keypoint_scores1"]
                    # SuperGlue expects descriptors transposed to (dim, n).
                    data["descriptors0"] = (
                        data["descriptors0"].transpose(-1, -2).contiguous()
                    )
                    data["descriptors1"] = (
                        data["descriptors1"].transpose(-1, -2).contiguous()
                    )
                    runtime = measure(matcher, data, device=device, r=args.repeat)[
                        "mean"
                    ]
                    results[pair_name][name].append(
                        1000 / runtime if args.measure == "throughput" else runtime
                    )
                ax.plot(
                    args.num_keypoints, results[pair_name][name], label=name, marker="o"
                )
            del matcher, data, image0, image1, feats0, feats1

    # Dump one table per image pair to stdout.
    for pair_name, pair_results in results.items():
        print_as_table(pair_results, pair_name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        plt.savefig(args.save, dpi=fig.dpi)
    plt.show()