demo_roma_camera_position_async_unity_bridge.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. #!/usr/bin/env python3
  2. """
  3. RoMa Unity 联调入口(低耦合桥接层):
  4. - 保持原始 demo_roma_camera_position_async.py 不修改(只做算法演示)
  5. - 本脚本负责 Unity 联调所需的通信/控制/转发:
  6. - 输入:支持 --input udp://0.0.0.0:PORT(底层使用 udp_jpeg_receiver.py 的 FrameHeader 分片重组)
  7. - 图像转发:将最新 JPEG 转发给 Unity 显示(FrameHeader 分片,--forward_port)
  8. - 结果回传:将坐标结果发回 Unity(20B 协议,--result_port)
  9. - 控制口:n/s/r(--control_port),支持 dxcam 截屏(s)作为参考图
  10. 说明:
  11. - `udp_jpeg_receiver.py` 是按 orangepizero2w/UDP_PROTOCOL.md 的 FrameHeader 协议接收并重组 JPEG。
  12. - 这里的 `--input udp://...` 与原 demo 保持一致的参数风格,便于 NetworkConfig 占位符复用。
  13. """
  14. from __future__ import annotations
  15. import argparse
  16. import os
  17. import sys
  18. import time
  19. import socket
  20. import threading
  21. import json
  22. from pathlib import Path
  23. import cv2
  24. import numpy as np
  25. import torch
  26. from PIL import Image
# Make the package root importable before the `romatch` import: parents[1] is
# the parent of this script's directory (named _ROMA_ROOT; presumably the RoMa
# repository root that contains `romatch`). It is inserted at position 0 so it
# takes precedence over any other copy found later on sys.path.
_ROMA_ROOT = Path(__file__).resolve().parents[1]
if str(_ROMA_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROMA_ROOT))
  30. from romatch import roma_outdoor, tiny_roma_v1_outdoor
  31. from udp_jpeg_receiver import UDPJPEGReceiver
  32. from udp_frame_forwarder import UDPFrameForwarder
  33. from udp_result_sender import UDPResultSender
  34. try:
  35. import dxcam
  36. try:
  37. _DXCAM_CAMERA = dxcam.create(output_color="RGB")
  38. DXCAM_AVAILABLE = True
  39. except Exception:
  40. _DXCAM_CAMERA = None
  41. DXCAM_AVAILABLE = False
  42. except ImportError:
  43. dxcam = None # type: ignore
  44. _DXCAM_CAMERA = None
  45. DXCAM_AVAILABLE = False
  46. torch.set_grad_enabled(False)
  47. def parse_args():
  48. p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  49. p.add_argument("--input", type=str, default="udp://0.0.0.0:9000", help="udp://host:port (FrameHeader JPEG)")
  50. p.add_argument("--reference_image", type=str, default=None, help="Optional reference image")
  51. p.add_argument("--resize", type=int, nargs="+", default=[320, 240])
  52. p.add_argument("--model", type=str, default="tiny", choices=["tiny", "roma"])
  53. p.add_argument("--sample_num", type=int, default=500)
  54. p.add_argument("--sample_thresh", type=float, default=0.05)
  55. p.add_argument("--ransac_reproj_threshold", type=float, default=4.0)
  56. p.add_argument("--min_matches", type=int, default=120)
  57. p.add_argument("--min_inlier_ratio", type=float, default=0.08)
  58. p.add_argument("--roma_interval", type=int, default=10)
  59. p.add_argument("--smooth_alpha", type=float, default=0.9)
  60. p.add_argument("--trail_len", type=int, default=150)
  61. # Unity 相关
  62. p.add_argument("--result_ip", type=str, default="127.0.0.1")
  63. p.add_argument("--result_port", type=int, default=12348)
  64. p.add_argument("--control_port", type=int, default=12349)
  65. p.add_argument("--unity_ref_port", type=int, default=12347, help="Optional UDP port to receive Unity GameView JPEG for 'r' reference update")
  66. p.add_argument("--forward_ip", type=str, default="127.0.0.1")
  67. p.add_argument("--forward_port", type=int, default=12366)
  68. p.add_argument("--forward_fps", type=float, default=30.0)
  69. p.add_argument("--device_info_ip", type=str, default=None, help="Unity IP for device info JSON (defaults to --result_ip)")
  70. p.add_argument("--device_info_port", type=int, default=12350, help="Unity port for device info JSON")
  71. p.add_argument("--device_info_interval", type=float, default=1.0, help="Device info report interval seconds")
  72. p.add_argument("--show_fps", action="store_true")
  73. p.add_argument("--max_fps", type=float, default=90.0)
  74. p.add_argument("--max_display_fps", type=float, default=60.0)
  75. p.add_argument("--timer_print_interval", type=float, default=0.0)
  76. p.add_argument("--idle_sleep_ms", type=float, default=2.0)
  77. p.add_argument("--force_cpu", action="store_true")
  78. p.add_argument("--no_display", action="store_true")
  79. p.add_argument("--no_ui", action="store_true")
  80. p.add_argument("--log_control_events", action="store_true", help="打印 n/s/r 等参考图更新日志")
  81. p.add_argument("--log_timer", action="store_true", help="打印周期性 FPS/matches/inliers 日志(需 timer_print_interval>0)")
  82. p.add_argument("--log_send_result", action="store_true", help="打印周期性坐标发送日志(需 timer_print_interval>0)")
  83. return p.parse_args()
  84. def start_device_info_reporter(jpeg_receiver: UDPJPEGReceiver, unity_ip: str, unity_port: int, interval_s: float = 1.0):
  85. stop_evt = threading.Event()
  86. def _loop():
  87. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  88. last_payload = None
  89. last_send_t = 0.0
  90. while not stop_evt.is_set():
  91. try:
  92. addr = jpeg_receiver.get_latest_sender() if jpeg_receiver is not None else None
  93. if addr and isinstance(addr, tuple) and len(addr) >= 2:
  94. ip, port = addr[0], int(addr[1])
  95. payload = {"device_ip": str(ip), "device_port": port, "ts": time.time()}
  96. now = time.time()
  97. if payload != last_payload or (now - last_send_t) >= interval_s:
  98. data = json.dumps(payload).encode("utf-8")
  99. sock.sendto(data, (unity_ip, unity_port))
  100. last_payload = payload
  101. last_send_t = now
  102. except Exception:
  103. pass
  104. time.sleep(0.05)
  105. try:
  106. sock.close()
  107. except Exception:
  108. pass
  109. th = threading.Thread(target=_loop, daemon=True, name="DeviceInfoReporter")
  110. th.start()
  111. return stop_evt
  112. def maybe_resize(frame_bgr: np.ndarray, resize_opt) -> np.ndarray:
  113. if frame_bgr is None:
  114. return frame_bgr
  115. if len(resize_opt) == 2:
  116. return cv2.resize(frame_bgr, tuple(resize_opt))
  117. if len(resize_opt) == 1 and resize_opt[0] > 0:
  118. h, w = frame_bgr.shape[:2]
  119. scale = resize_opt[0] / max(h, w)
  120. return cv2.resize(frame_bgr, (int(w * scale), int(h * scale)))
  121. return frame_bgr
  122. def capture_screen_bgr(opt) -> np.ndarray | None:
  123. if not DXCAM_AVAILABLE or _DXCAM_CAMERA is None:
  124. return None
  125. try:
  126. rgb = _DXCAM_CAMERA.grab()
  127. if rgb is None:
  128. return None
  129. if len(rgb.shape) == 3 and rgb.shape[2] == 3:
  130. bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
  131. elif len(rgb.shape) == 3 and rgb.shape[2] == 4:
  132. bgr = cv2.cvtColor(rgb, cv2.COLOR_BGRA2BGR)
  133. else:
  134. bgr = rgb if len(rgb.shape) == 3 else cv2.cvtColor(rgb, cv2.COLOR_GRAY2BGR)
  135. bgr = maybe_resize(bgr, opt.resize)
  136. return bgr
  137. except Exception:
  138. return None
  139. def draw_reference_view(ref_frame_bgr, camera_center_ref, is_valid, num_matches, inliers_ratio, trail_points):
  140. if ref_frame_bgr is None:
  141. return None
  142. vis = ref_frame_bgr.copy()
  143. h, w = vis.shape[:2]
  144. center = (w // 2, h // 2)
  145. cv2.circle(vis, center, 10, (0, 255, 0), 2)
  146. if trail_points and len(trail_points) >= 2:
  147. for i in range(1, len(trail_points)):
  148. p0 = (int(trail_points[i - 1][0]), int(trail_points[i - 1][1]))
  149. p1 = (int(trail_points[i][0]), int(trail_points[i][1]))
  150. cv2.line(vis, p0, p1, (255, 255, 0), 2, cv2.LINE_AA)
  151. if is_valid and camera_center_ref is not None:
  152. p = (int(camera_center_ref[0]), int(camera_center_ref[1]))
  153. cv2.circle(vis, p, 10, (0, 0, 255), 2)
  154. cv2.line(vis, center, p, (255, 0, 255), 2)
  155. cv2.putText(vis, f"matches={num_matches} inliers={inliers_ratio:.1%}", (10, 30),
  156. cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0) if is_valid else (0, 0, 255), 2)
  157. return vis
  158. def start_control_listener(port: int, on_n, on_s, on_r):
  159. if port <= 0:
  160. return None
  161. stop_event = threading.Event()
  162. def _worker():
  163. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  164. sock.bind(("0.0.0.0", port))
  165. sock.settimeout(0.5)
  166. print(f"[Control] Listening on 0.0.0.0:{port} (n/s/r)", flush=True)
  167. while not stop_event.is_set():
  168. try:
  169. try:
  170. data, _addr = sock.recvfrom(1024)
  171. except socket.timeout:
  172. continue
  173. if not data:
  174. continue
  175. c = data[0]
  176. if c in (ord("n"), ord("N"), 1):
  177. on_n()
  178. elif c == ord("s"):
  179. on_s()
  180. elif c == ord("r"):
  181. on_r()
  182. except OSError:
  183. break
  184. except Exception as exc:
  185. print(f"[Control] error: {exc}", flush=True)
  186. try:
  187. sock.close()
  188. except Exception:
  189. pass
  190. th = threading.Thread(target=_worker, daemon=True, name="RomaControlListener")
  191. th.start()
  192. return stop_event
  193. def start_unity_ref_receiver(port: int):
  194. if port <= 0:
  195. return None, None, None
  196. stop_event = threading.Event()
  197. latest_lock = threading.Lock()
  198. latest_jpeg = {"data": None}
  199. def _worker():
  200. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  201. sock.bind(("0.0.0.0", port))
  202. sock.settimeout(0.5)
  203. print(f"[UnityRef] Listening on 0.0.0.0:{port} (JPEG single-packet)", flush=True)
  204. while not stop_event.is_set():
  205. try:
  206. try:
  207. data, _addr = sock.recvfrom(65535)
  208. except socket.timeout:
  209. continue
  210. if not data or len(data) < 4:
  211. continue
  212. # basic JPEG validation
  213. if data[:2] != b"\xFF\xD8" or data[-2:] != b"\xFF\xD9":
  214. continue
  215. with latest_lock:
  216. latest_jpeg["data"] = data
  217. except OSError:
  218. break
  219. except Exception as exc:
  220. print(f"[UnityRef] error: {exc}", flush=True)
  221. try:
  222. sock.close()
  223. except Exception:
  224. pass
  225. th = threading.Thread(target=_worker, daemon=True, name="UnityRefReceiver")
  226. th.start()
  227. def _get_latest():
  228. with latest_lock:
  229. return latest_jpeg["data"]
  230. return stop_event, th, _get_latest
  231. def parse_udp_input(input_str: str):
  232. if not input_str.startswith("udp://"):
  233. raise ValueError("This bridge expects --input udp://host:port")
  234. parts = input_str.replace("udp://", "").split(":")
  235. if len(parts) == 2:
  236. host = parts[0] if parts[0] else "0.0.0.0"
  237. port = int(parts[1])
  238. else:
  239. host = "0.0.0.0"
  240. port = int(parts[0])
  241. return host, port
  242. def main():
  243. opt = parse_args()
  244. if opt.no_ui:
  245. sys.stdout = open(os.devnull, "w")
  246. sys.stderr = open(os.devnull, "w")
  247. device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
  248. print(f'Running inference on device "{device}"', flush=True)
  249. roma_model = tiny_roma_v1_outdoor(device=torch.device(device)) if opt.model == "tiny" else roma_outdoor(device=torch.device(device))
  250. roma_model.sample_thresh = opt.sample_thresh
  251. # UDP input receiver (FrameHeader)
  252. host, port = parse_udp_input(opt.input)
  253. receiver = UDPJPEGReceiver(host=host, port=port, timeout=0.5, max_frame_buffers=8)
  254. receiver.start()
  255. forwarder = UDPFrameForwarder(target_ip=opt.forward_ip, target_port=opt.forward_port)
  256. result_sender = UDPResultSender(unity_ip=opt.result_ip, unity_port=opt.result_port)
  257. device_info_ip = opt.device_info_ip or opt.result_ip
  258. device_info_stop = start_device_info_reporter(
  259. jpeg_receiver=receiver,
  260. unity_ip=device_info_ip,
  261. unity_port=int(opt.device_info_port),
  262. interval_s=float(opt.device_info_interval),
  263. )
  264. # Control events
  265. set_ref_now_event = threading.Event()
  266. screen_ref_event = threading.Event()
  267. next_frame_ref_event = threading.Event()
  268. def _on_n(): set_ref_now_event.set()
  269. def _on_s(): screen_ref_event.set()
  270. def _on_r(): next_frame_ref_event.set()
  271. control_stop = start_control_listener(opt.control_port, _on_n, _on_s, _on_r)
  272. unity_ref_stop, unity_ref_thread, get_unity_ref_jpeg = start_unity_ref_receiver(opt.unity_ref_port)
  273. # Reference frame
  274. ref_frame = None
  275. if opt.reference_image:
  276. img = cv2.imread(opt.reference_image, cv2.IMREAD_COLOR)
  277. if img is not None:
  278. ref_frame = maybe_resize(img, opt.resize)
  279. last_good_H = None
  280. last_camera_center_ref = None
  281. trail_points = []
  282. fps_ema = 0.0
  283. last_timer = time.time()
  284. last_forward = 0.0
  285. last_display = 0.0
  286. try:
  287. while True:
  288. loop_start = time.time()
  289. # Control: screen capture as reference
  290. if screen_ref_event.is_set():
  291. screen_ref_event.clear()
  292. sc = capture_screen_bgr(opt)
  293. if sc is not None:
  294. ref_frame = sc.copy()
  295. last_good_H = None
  296. last_camera_center_ref = None
  297. trail_points = []
  298. if opt.log_control_events:
  299. print("[Control] Reference updated from screen capture", flush=True)
  300. elif not DXCAM_AVAILABLE:
  301. if opt.log_control_events:
  302. print("[Control] Screen capture skipped: pip install dxcam", flush=True)
  303. frame = receiver.get_image(timeout=0.2)
  304. if frame is None:
  305. if opt.idle_sleep_ms and opt.idle_sleep_ms > 0:
  306. time.sleep(opt.idle_sleep_ms / 1000.0)
  307. continue
  308. frame = maybe_resize(frame, opt.resize)
  309. # forward latest JPEG bytes (non-blocking)
  310. if opt.forward_fps <= 0 or (time.time() - last_forward) >= (1.0 / max(opt.forward_fps, 1e-6)):
  311. jpeg = receiver.get_latest_jpeg()
  312. if jpeg:
  313. forwarder.send_jpeg(jpeg)
  314. last_forward = time.time()
  315. if ref_frame is None or set_ref_now_event.is_set():
  316. set_ref_now_event.clear()
  317. ref_frame = frame.copy()
  318. last_good_H = None
  319. last_camera_center_ref = None
  320. trail_points = []
  321. if opt.log_control_events:
  322. print("[Control] Reference updated from current frame", flush=True)
  323. continue
  324. if next_frame_ref_event.is_set():
  325. next_frame_ref_event.clear()
  326. # 优先使用 Unity 发来的 GameView JPEG 作为参考图(若可用)
  327. applied = False
  328. if get_unity_ref_jpeg is not None:
  329. jpeg = get_unity_ref_jpeg()
  330. if jpeg is not None:
  331. arr = np.frombuffer(jpeg, dtype=np.uint8)
  332. img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
  333. if img is not None:
  334. ref_frame = maybe_resize(img, opt.resize)
  335. applied = True
  336. if not applied:
  337. ref_frame = frame.copy()
  338. last_good_H = None
  339. last_camera_center_ref = None
  340. trail_points = []
  341. print("[Control] Reference updated from Unity GameView JPEG (r)" if applied else "[Control] Reference updated from next frame (r)", flush=True)
  342. continue
  343. # RoMa estimate H_ref->cur
  344. h_ref, w_ref = ref_frame.shape[:2]
  345. h_cur, w_cur = frame.shape[:2]
  346. num_matches = 0
  347. inliers_ratio = 0.0
  348. H_ref_to_cur = None
  349. ref_pil = Image.fromarray(cv2.cvtColor(ref_frame, cv2.COLOR_BGR2RGB))
  350. cur_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
  351. warp, certainty = roma_model.match(ref_pil, cur_pil) if opt.model == "tiny" else roma_model.match(ref_pil, cur_pil, device=torch.device(device))
  352. matches, _ = roma_model.sample(warp, certainty, num=opt.sample_num)
  353. k_ref, k_cur = roma_model.to_pixel_coordinates(matches, h_ref, w_ref, h_cur, w_cur)
  354. pts_ref = k_ref.detach().cpu().numpy().astype(np.float32)
  355. pts_cur = k_cur.detach().cpu().numpy().astype(np.float32)
  356. num_matches = len(pts_ref)
  357. if num_matches >= opt.min_matches:
  358. H_tmp, mask = cv2.findHomography(pts_ref, pts_cur, cv2.RANSAC, opt.ransac_reproj_threshold)
  359. if H_tmp is not None and mask is not None:
  360. inliers_ratio = float(mask.mean())
  361. if inliers_ratio >= opt.min_inlier_ratio:
  362. H_ref_to_cur = H_tmp
  363. last_good_H = H_tmp
  364. if H_ref_to_cur is None:
  365. H_ref_to_cur = last_good_H
  366. # camera center mapping
  367. camera_center_current = (w_cur // 2, h_cur // 2)
  368. camera_center_ref = None
  369. is_valid = False
  370. if H_ref_to_cur is not None:
  371. try:
  372. H_cur_to_ref = np.linalg.inv(H_ref_to_cur)
  373. center_cur = np.array([[camera_center_current]], dtype=np.float32)
  374. center_ref_now = cv2.perspectiveTransform(center_cur, H_cur_to_ref)[0, 0]
  375. center_ref_now = np.clip(center_ref_now, [0, 0], [w_ref - 1, h_ref - 1])
  376. if last_camera_center_ref is None:
  377. camera_center_ref = center_ref_now
  378. else:
  379. camera_center_ref = opt.smooth_alpha * last_camera_center_ref + (1.0 - opt.smooth_alpha) * center_ref_now
  380. last_camera_center_ref = camera_center_ref
  381. is_valid = True
  382. trail_points.append(camera_center_ref.copy())
  383. if len(trail_points) > opt.trail_len:
  384. trail_points = trail_points[-opt.trail_len:]
  385. except np.linalg.LinAlgError:
  386. is_valid = False
  387. # send result
  388. if is_valid and camera_center_ref is not None:
  389. result_sender.send_result(True, num_matches, inliers_ratio, float(camera_center_ref[0]), float(camera_center_ref[1]))
  390. else:
  391. result_sender.send_result(False, num_matches, 0.0, 0.0, 0.0)
  392. # display
  393. if not opt.no_display:
  394. nowp = time.time()
  395. if opt.max_display_fps <= 0 or (nowp - last_display) >= (1.0 / max(opt.max_display_fps, 1e-6)):
  396. ref_view = draw_reference_view(ref_frame, camera_center_ref, is_valid, num_matches, inliers_ratio, trail_points)
  397. if ref_view is not None and opt.show_fps:
  398. cv2.putText(ref_view, f"FPS: {fps_ema:.1f}", (10, 60),
  399. cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
  400. if ref_view is not None:
  401. cv2.imshow("Camera Position in Reference", ref_view)
  402. cv2.imshow("Live Camera", frame)
  403. key = cv2.waitKey(1) & 0xFF
  404. if key == ord("q"):
  405. break
  406. if key == ord("n"):
  407. set_ref_now_event.set()
  408. if key == ord("s"):
  409. screen_ref_event.set()
  410. last_display = nowp
  411. # FPS & throttle
  412. if opt.max_fps and opt.max_fps > 0:
  413. elapsed = time.time() - loop_start
  414. target_dt = 1.0 / float(opt.max_fps)
  415. if elapsed < target_dt:
  416. time.sleep(target_dt - elapsed)
  417. elapsed = time.time() - loop_start
  418. fps_ema = 0.9 * fps_ema + 0.1 * (1.0 / max(elapsed, 1e-6))
  419. if opt.timer_print_interval and opt.timer_print_interval > 0 and (time.time() - last_timer) >= opt.timer_print_interval:
  420. if opt.log_timer:
  421. print(f"[RomaBridge] FPS={fps_ema:.1f} matches={num_matches} inliers={inliers_ratio:.1%}", flush=True)
  422. if opt.log_send_result:
  423. if is_valid and camera_center_ref is not None:
  424. print(f"[RomaBridge] send_result valid=1 x={float(camera_center_ref[0]):.2f} y={float(camera_center_ref[1]):.2f} matches={num_matches} inliers={inliers_ratio:.1%}", flush=True)
  425. else:
  426. print(f"[RomaBridge] send_result valid=0 matches={num_matches}", flush=True)
  427. last_timer = time.time()
  428. finally:
  429. try:
  430. receiver.stop()
  431. except Exception:
  432. pass
  433. try:
  434. forwarder.close()
  435. except Exception:
  436. pass
  437. try:
  438. result_sender.close()
  439. except Exception:
  440. pass
  441. if device_info_stop is not None:
  442. try:
  443. device_info_stop.set()
  444. except Exception:
  445. pass
  446. if control_stop is not None:
  447. control_stop.set()
  448. if unity_ref_stop is not None:
  449. unity_ref_stop.set()
  450. try:
  451. cv2.destroyAllWindows()
  452. except Exception:
  453. pass
# Script entry point: run the bridge only when executed directly, not on import.
if __name__ == "__main__":
    main()