demo_loftr.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. front_matter = """
  2. ------------------------------------------------------------------------
  3. Online demo for [LoFTR](https://zju3dv.github.io/loftr/).
  4. This demo is heavily inspired by [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork/).
  5. We thank the authors for their execellent work.
  6. ------------------------------------------------------------------------
  7. """
  8. import os
  9. import argparse
  10. from pathlib import Path
  11. import cv2
  12. import torch
  13. import numpy as np
  14. import matplotlib.cm as cm
  15. os.sys.path.append("../") # Add the project directory
  16. from src.loftr import LoFTR, default_cfg
  17. from src.config.default import get_cfg_defaults
  18. try:
  19. from demo.utils import (AverageTimer, VideoStreamer,
  20. make_matching_plot_fast, make_matching_plot, frame2tensor)
  21. except:
  22. raise ImportError("This demo requires utils.py from SuperGlue, please use run_demo.sh to start this script.")
  23. torch.set_grad_enabled(False)
  24. if __name__ == '__main__':
  25. parser = argparse.ArgumentParser(
  26. description='LoFTR online demo',
  27. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  28. parser.add_argument('--weight', type=str, help="Path to the checkpoint.")
  29. parser.add_argument(
  30. '--input', type=str, default='0',
  31. help='ID of a USB webcam, URL of an IP camera, '
  32. 'or path to an image directory or movie file')
  33. parser.add_argument(
  34. '--output_dir', type=str, default=None,
  35. help='Directory where to write output frames (If None, no output)')
  36. parser.add_argument(
  37. '--image_glob', type=str, nargs='+', default=['*.png', '*.jpg', '*.jpeg'],
  38. help='Glob if a directory of images is specified')
  39. parser.add_argument(
  40. '--skip', type=int, default=1,
  41. help='Images to skip if input is a movie or directory')
  42. parser.add_argument(
  43. '--max_length', type=int, default=1000000,
  44. help='Maximum length if input is a movie or directory')
  45. parser.add_argument(
  46. '--resize', type=int, nargs='+', default=[640, 480],
  47. help='Resize the input image before running inference. If two numbers, '
  48. 'resize to the exact dimensions, if one number, resize the max '
  49. 'dimension, if -1, do not resize')
  50. parser.add_argument(
  51. '--no_display', action='store_true',
  52. help='Do not display images to screen. Useful if running remotely')
  53. parser.add_argument(
  54. '--save_video', action='store_true',
  55. help='Save output (with match visualizations) to a video.')
  56. parser.add_argument(
  57. '--save_input', action='store_true',
  58. help='Save the input images to a video (for gathering repeatable input source).')
  59. parser.add_argument(
  60. '--skip_frames', type=int, default=1,
  61. help="Skip frames from webcam input.")
  62. parser.add_argument(
  63. '--top_k', type=int, default=2000, help="The max vis_range (please refer to the code).")
  64. parser.add_argument(
  65. '--bottom_k', type=int, default=0, help="The min vis_range (please refer to the code).")
  66. opt = parser.parse_args()
  67. print(front_matter)
  68. parser.print_help()
  69. if len(opt.resize) == 2 and opt.resize[1] == -1:
  70. opt.resize = opt.resize[0:1]
  71. if len(opt.resize) == 2:
  72. print('Will resize to {}x{} (WxH)'.format(
  73. opt.resize[0], opt.resize[1]))
  74. elif len(opt.resize) == 1 and opt.resize[0] > 0:
  75. print('Will resize max dimension to {}'.format(opt.resize[0]))
  76. elif len(opt.resize) == 1:
  77. print('Will not resize images')
  78. else:
  79. raise ValueError('Cannot specify more than two integers for --resize')
  80. if torch.cuda.is_available():
  81. device = 'cuda'
  82. else:
  83. raise RuntimeError("GPU is required to run this demo.")
  84. # Initialize LoFTR
  85. matcher = LoFTR(config=default_cfg)
  86. matcher.load_state_dict(torch.load(opt.weight)['state_dict'])
  87. matcher = matcher.eval().to(device=device)
  88. # Configure I/O
  89. if opt.save_video:
  90. print('Writing video to loftr-matches.mp4...')
  91. writer = cv2.VideoWriter('loftr-matches.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, (640*2 + 10, 480))
  92. if opt.save_input:
  93. print('Writing video to demo-input.mp4...')
  94. input_writer = cv2.VideoWriter('demo-input.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, (640, 480))
  95. vs = VideoStreamer(opt.input, opt.resize, opt.skip,
  96. opt.image_glob, opt.max_length)
  97. frame, ret = vs.next_frame()
  98. assert ret, 'Error when reading the first frame (try different --input?)'
  99. frame_id = 0
  100. last_image_id = 0
  101. frame_tensor = frame2tensor(frame, device)
  102. last_data = {'image0': frame_tensor}
  103. last_frame = frame
  104. if opt.output_dir is not None:
  105. print('==> Will write outputs to {}'.format(opt.output_dir))
  106. Path(opt.output_dir).mkdir(exist_ok=True)
  107. # Create a window to display the demo.
  108. if not opt.no_display:
  109. window_name = 'LoFTR Matches'
  110. cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
  111. cv2.resizeWindow(window_name, (640*2, 480))
  112. else:
  113. print('Skipping visualization, will not show a GUI.')
  114. # Print the keyboard help menu.
  115. print('==> Keyboard control:\n'
  116. '\tn: select the current frame as the reference image (left)\n'
  117. '\td/f: move the range of the matches (ranked by confidence) to visualize\n'
  118. '\tc/v: increase/decrease the length of the visualization range (i.e., total number of matches) to show\n'
  119. '\tq: quit')
  120. timer = AverageTimer()
  121. vis_range = [opt.bottom_k, opt.top_k]
  122. while True:
  123. frame_id += 1
  124. frame, ret = vs.next_frame()
  125. if frame_id % opt.skip_frames != 0:
  126. # print("Skipping frame.")
  127. continue
  128. if opt.save_input:
  129. inp = np.stack([frame]*3, -1)
  130. inp_rgb = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
  131. input_writer.write(inp_rgb)
  132. if not ret:
  133. print('Finished demo_loftr.py')
  134. break
  135. timer.update('data')
  136. stem0, stem1 = last_image_id, vs.i - 1
  137. frame_tensor = frame2tensor(frame, device)
  138. last_data = {**last_data, 'image1': frame_tensor}
  139. matcher(last_data)
  140. total_n_matches = len(last_data['mkpts0_f'])
  141. mkpts0 = last_data['mkpts0_f'].cpu().numpy()[vis_range[0]:vis_range[1]]
  142. mkpts1 = last_data['mkpts1_f'].cpu().numpy()[vis_range[0]:vis_range[1]]
  143. mconf = last_data['mconf'].cpu().numpy()[vis_range[0]:vis_range[1]]
  144. # Normalize confidence.
  145. if len(mconf) > 0:
  146. conf_vis_min = 0.
  147. conf_min = mconf.min()
  148. conf_max = mconf.max()
  149. mconf = (mconf - conf_vis_min) / (conf_max - conf_vis_min + 1e-5)
  150. timer.update('forward')
  151. alpha = 0
  152. color = cm.jet(mconf, alpha=alpha)
  153. text = [
  154. f'LoFTR',
  155. '# Matches (showing/total): {}/{}'.format(len(mkpts0), total_n_matches),
  156. ]
  157. small_text = [
  158. f'Showing matches from {vis_range[0]}:{vis_range[1]}',
  159. f'Confidence Range: {conf_min:.2f}:{conf_max:.2f}',
  160. 'Image Pair: {:06}:{:06}'.format(stem0, stem1),
  161. ]
  162. out = make_matching_plot_fast(
  163. last_frame, frame, mkpts0, mkpts1, mkpts0, mkpts1, color, text,
  164. path=None, show_keypoints=False, small_text=small_text)
  165. # Save high quality png, optionally with dynamic alpha support (unreleased yet).
  166. # save_path = 'demo_vid/{:06}'.format(frame_id)
  167. # make_matching_plot(
  168. # last_frame, frame, mkpts0, mkpts1, mkpts0, mkpts1, color, text,
  169. # path=save_path, show_keypoints=opt.show_keypoints, small_text=small_text)
  170. if not opt.no_display:
  171. if opt.save_video:
  172. writer.write(out)
  173. cv2.imshow('LoFTR Matches', out)
  174. key = chr(cv2.waitKey(1) & 0xFF)
  175. if key == 'q':
  176. if opt.save_video:
  177. writer.release()
  178. if opt.save_input:
  179. input_writer.release()
  180. vs.cleanup()
  181. print('Exiting...')
  182. break
  183. elif key == 'n':
  184. last_data['image0'] = frame_tensor
  185. last_frame = frame
  186. last_image_id = (vs.i - 1)
  187. frame_id_left = frame_id
  188. elif key in ['d', 'f']:
  189. if key == 'd':
  190. if vis_range[0] >= 0:
  191. vis_range[0] -= 200
  192. vis_range[1] -= 200
  193. if key =='f':
  194. vis_range[0] += 200
  195. vis_range[1] += 200
  196. print(f'\nChanged the vis_range to {vis_range[0]}:{vis_range[1]}')
  197. elif key in ['c', 'v']:
  198. if key == 'c':
  199. vis_range[1] -= 50
  200. if key =='v':
  201. vis_range[1] += 50
  202. print(f'\nChanged the vis_range[1] to {vis_range[1]}')
  203. elif opt.output_dir is not None:
  204. stem = 'matches_{:06}_{:06}'.format(stem0, stem1)
  205. out_file = str(Path(opt.output_dir, stem + '.png'))
  206. print('\nWriting image to {}'.format(out_file))
  207. cv2.imwrite(out_file, out)
  208. else:
  209. raise ValueError("output_dir is required when no display is given.")
  210. timer.update('viz')
  211. timer.print()
  212. cv2.destroyAllWindows()
  213. vs.cleanup()