| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533 |
- #!/usr/bin/env python3
- """
- RoMa Unity 联调入口(低耦合桥接层):
- - 保持原始 demo_roma_camera_position_async.py 不修改(只做算法演示)
- - 本脚本负责 Unity 联调所需的通信/控制/转发:
- - 输入:支持 --input udp://0.0.0.0:PORT(底层使用 udp_jpeg_receiver.py 的 FrameHeader 分片重组)
- - 图像转发:将最新 JPEG 转发给 Unity 显示(FrameHeader 分片,--forward_port)
- - 结果回传:将坐标结果发回 Unity(20B 协议,--result_port)
- - 控制口:n/s/r(--control_port),支持 dxcam 截屏(s)作为参考图
- 说明:
- - `udp_jpeg_receiver.py` 是按 orangepizero2w/UDP_PROTOCOL.md 的 FrameHeader 协议接收并重组 JPEG。
- - 这里的 `--input udp://...` 与原 demo 保持一致的参数风格,便于 NetworkConfig 占位符复用。
- """
- from __future__ import annotations
- import argparse
- import os
- import sys
- import time
- import socket
- import threading
- import json
- from pathlib import Path
- import cv2
- import numpy as np
- import torch
- from PIL import Image
# Directory one level above the folder that contains this script. It is pushed
# to the front of sys.path (only if not already present) before the imports
# that follow, so that they can be resolved from there.
# NOTE(review): presumably this is the RoMa repository root that holds the
# `romatch` package — confirm against the repository layout.
_ROMA_ROOT = Path(__file__).resolve().parents[1]
if str(_ROMA_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROMA_ROOT))
- from romatch import roma_outdoor, tiny_roma_v1_outdoor
- from udp_jpeg_receiver import UDPJPEGReceiver
- from udp_frame_forwarder import UDPFrameForwarder
- from udp_result_sender import UDPResultSender
# Optional dependency: dxcam backs the "s" (screen capture) reference command.
# The camera object is created once at import time. After this block:
#   * DXCAM_AVAILABLE is True only when both the import and the camera
#     creation succeeded;
#   * _DXCAM_CAMERA is the camera object, or None when unavailable.
# Any failure here is swallowed on purpose so the bridge still runs without
# screen capture support.
try:
    import dxcam

    try:
        # Frames are requested in RGB channel order.
        _DXCAM_CAMERA = dxcam.create(output_color="RGB")
        DXCAM_AVAILABLE = True
    except Exception:
        # dxcam imported but the camera could not be created.
        _DXCAM_CAMERA = None
        DXCAM_AVAILABLE = False
except ImportError:
    dxcam = None  # type: ignore
    _DXCAM_CAMERA = None
    DXCAM_AVAILABLE = False

# This script only runs inference, so gradient tracking is switched off
# globally for the whole process.
torch.set_grad_enabled(False)
def parse_args(argv=None):
    """Build the bridge's command-line parser and parse the arguments.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``; passing a list makes the
            parser usable from other code and from tests.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Input stream, reference image and matching parameters.
    p.add_argument("--input", type=str, default="udp://0.0.0.0:9000", help="udp://host:port (FrameHeader JPEG)")
    p.add_argument("--reference_image", type=str, default=None, help="Optional reference image")
    p.add_argument("--resize", type=int, nargs="+", default=[320, 240])
    p.add_argument("--model", type=str, default="tiny", choices=["tiny", "roma"])
    p.add_argument("--sample_num", type=int, default=500)
    p.add_argument("--sample_thresh", type=float, default=0.05)
    p.add_argument("--ransac_reproj_threshold", type=float, default=4.0)
    p.add_argument("--min_matches", type=int, default=120)
    p.add_argument("--min_inlier_ratio", type=float, default=0.08)
    p.add_argument("--roma_interval", type=int, default=10)
    p.add_argument("--smooth_alpha", type=float, default=0.9)
    p.add_argument("--trail_len", type=int, default=150)

    # Unity-related options (result, control, reference and forwarding ports).
    p.add_argument("--result_ip", type=str, default="127.0.0.1")
    p.add_argument("--result_port", type=int, default=12348)
    p.add_argument("--control_port", type=int, default=12349)
    p.add_argument("--unity_ref_port", type=int, default=12347, help="Optional UDP port to receive Unity GameView JPEG for 'r' reference update")
    p.add_argument("--forward_ip", type=str, default="127.0.0.1")
    p.add_argument("--forward_port", type=int, default=12366)
    p.add_argument("--forward_fps", type=float, default=30.0)
    p.add_argument("--device_info_ip", type=str, default=None, help="Unity IP for device info JSON (defaults to --result_ip)")
    p.add_argument("--device_info_port", type=int, default=12350, help="Unity port for device info JSON")
    p.add_argument("--device_info_interval", type=float, default=1.0, help="Device info report interval seconds")

    # Display, throttling and logging switches.
    p.add_argument("--show_fps", action="store_true")
    p.add_argument("--max_fps", type=float, default=90.0)
    p.add_argument("--max_display_fps", type=float, default=60.0)
    p.add_argument("--timer_print_interval", type=float, default=0.0)
    p.add_argument("--idle_sleep_ms", type=float, default=2.0)
    p.add_argument("--force_cpu", action="store_true")
    p.add_argument("--no_display", action="store_true")
    p.add_argument("--no_ui", action="store_true")
    p.add_argument("--log_control_events", action="store_true", help="打印 n/s/r 等参考图更新日志")
    p.add_argument("--log_timer", action="store_true", help="打印周期性 FPS/matches/inliers 日志(需 timer_print_interval>0)")
    p.add_argument("--log_send_result", action="store_true", help="打印周期性坐标发送日志(需 timer_print_interval>0)")
    return p.parse_args(argv)
def start_device_info_reporter(jpeg_receiver: UDPJPEGReceiver, unity_ip: str, unity_port: int, interval_s: float = 1.0):
    """Start a daemon thread that reports the video sender's address to Unity.

    The thread polls ``jpeg_receiver.get_latest_sender()`` roughly every 50 ms.
    When it yields an ``(ip, port, ...)`` tuple, a UTF-8 JSON datagram with the
    keys ``device_ip``, ``device_port`` and ``ts`` (wall-clock send time) is
    sent to ``(unity_ip, unity_port)``.

    A datagram is sent when the sender endpoint changed since the last report,
    or when at least ``interval_s`` seconds have passed since the last report.
    The timestamp is deliberately excluded from the change check: it differs
    on every poll, so including it would make ``interval_s`` ineffective.

    Args:
        jpeg_receiver: Object exposing ``get_latest_sender()``; may be ``None``,
            in which case nothing is ever sent.
        unity_ip: Destination IP for the JSON datagrams.
        unity_port: Destination UDP port for the JSON datagrams.
        interval_s: Minimum seconds between reports of an unchanged endpoint.

    Returns:
        A ``threading.Event``; set it to stop the reporter thread.
    """
    stop_evt = threading.Event()

    def _loop():
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        last_endpoint = None
        last_send_t = 0.0
        try:
            while not stop_evt.is_set():
                try:
                    addr = jpeg_receiver.get_latest_sender() if jpeg_receiver is not None else None
                    if addr and isinstance(addr, tuple) and len(addr) >= 2:
                        endpoint = (str(addr[0]), int(addr[1]))
                        now = time.time()
                        if endpoint != last_endpoint or (now - last_send_t) >= interval_s:
                            payload = {"device_ip": endpoint[0], "device_port": endpoint[1], "ts": now}
                            sock.sendto(json.dumps(payload).encode("utf-8"), (unity_ip, unity_port))
                            last_endpoint = endpoint
                            last_send_t = now
                except Exception:
                    # Best-effort telemetry: a failed poll or send must never
                    # take the reporter thread down; retry on the next tick.
                    pass
                # Event.wait doubles as the poll delay and returns early once
                # the stop event is set, so shutdown is not delayed.
                stop_evt.wait(0.05)
        finally:
            sock.close()

    th = threading.Thread(target=_loop, daemon=True, name="DeviceInfoReporter")
    th.start()
    return stop_evt
def maybe_resize(frame_bgr: np.ndarray, resize_opt) -> np.ndarray:
    """Resize a frame according to the ``--resize`` option.

    Args:
        frame_bgr: Image array, or ``None``.
        resize_opt: Sequence of ints taken from ``--resize``:
            * two values ``[W, H]``: resize to exactly ``W x H``;
            * one positive value ``[S]``: scale so the longer side becomes
              ``S`` while keeping the aspect ratio;
            * anything else (empty, non-positive values, more than two
              values): no resizing.

    Returns:
        The resized image, or ``frame_bgr`` itself when no resizing applies
        (including when ``frame_bgr`` is ``None``).
    """
    if frame_bgr is None:
        return frame_bgr

    if len(resize_opt) == 2:
        width, height = int(resize_opt[0]), int(resize_opt[1])
        # cv2.resize rejects non-positive target sizes; treat e.g. "-1 -1"
        # as "do not resize" instead of crashing on every frame.
        if width <= 0 or height <= 0:
            return frame_bgr
        return cv2.resize(frame_bgr, (width, height))

    if len(resize_opt) == 1 and resize_opt[0] > 0:
        h, w = frame_bgr.shape[:2]
        longest = max(h, w)
        if longest <= 0:
            # Empty frame: nothing to scale, and it would divide by zero.
            return frame_bgr
        scale = resize_opt[0] / longest
        # Clamp to 1 px so extreme aspect ratios never truncate to size 0.
        return cv2.resize(frame_bgr, (max(1, int(w * scale)), max(1, int(h * scale))))

    return frame_bgr
def capture_screen_bgr(opt) -> np.ndarray | None:
    """Grab one desktop frame through dxcam and return it as a BGR image.

    Args:
        opt: Parsed options; only ``opt.resize`` is read (passed on to
            ``maybe_resize``).

    Returns:
        The captured frame converted to BGR and resized, or ``None`` when
        dxcam is unavailable, the grab yields nothing, or any step raises.
    """
    if _DXCAM_CAMERA is None or not DXCAM_AVAILABLE:
        return None

    try:
        shot = _DXCAM_CAMERA.grab()
        if shot is None:
            return None

        if len(shot.shape) != 3:
            # Anything that is not a 3-D array is treated as grayscale.
            converted = cv2.cvtColor(shot, cv2.COLOR_GRAY2BGR)
        else:
            channels = shot.shape[2]
            if channels == 3:
                converted = cv2.cvtColor(shot, cv2.COLOR_RGB2BGR)
            elif channels == 4:
                # NOTE(review): 4-channel input is converted as BGRA, although
                # the module-level camera is created with output_color="RGB";
                # confirm the channel order if the camera config ever changes.
                converted = cv2.cvtColor(shot, cv2.COLOR_BGRA2BGR)
            else:
                # Unexpected channel count: pass the array through untouched.
                converted = shot

        return maybe_resize(converted, opt.resize)
    except Exception:
        # Screen capture is best-effort; callers treat None as "no capture".
        return None
def draw_reference_view(ref_frame_bgr, camera_center_ref, is_valid, num_matches, inliers_ratio, trail_points):
    """Render the reference image annotated with the estimated camera position.

    Drawn on a copy of ``ref_frame_bgr`` (the input is not modified):
      * a circle at the image centre;
      * the trail as a polyline through consecutive ``trail_points``;
      * when ``is_valid`` and a position is given, a circle at
        ``camera_center_ref`` plus a line from the image centre to it;
      * a "matches=... inliers=..." status line whose colour depends on
        ``is_valid``.

    Returns:
        The annotated copy, or ``None`` when ``ref_frame_bgr`` is ``None``.
    """
    if ref_frame_bgr is None:
        return None

    canvas = ref_frame_bgr.copy()
    height, width = canvas.shape[:2]
    image_center = (width // 2, height // 2)
    cv2.circle(canvas, image_center, 10, (0, 255, 0), 2)

    # Trail: one anti-aliased segment per pair of consecutive points.
    if trail_points and len(trail_points) >= 2:
        for idx in range(len(trail_points) - 1):
            seg_start = trail_points[idx]
            seg_end = trail_points[idx + 1]
            cv2.line(
                canvas,
                (int(seg_start[0]), int(seg_start[1])),
                (int(seg_end[0]), int(seg_end[1])),
                (255, 255, 0),
                2,
                cv2.LINE_AA,
            )

    # Current estimate and its offset from the image centre.
    if is_valid and camera_center_ref is not None:
        marker = (int(camera_center_ref[0]), int(camera_center_ref[1]))
        cv2.circle(canvas, marker, 10, (0, 0, 255), 2)
        cv2.line(canvas, image_center, marker, (255, 0, 255), 2)

    status_color = (0, 255, 0) if is_valid else (0, 0, 255)
    status_text = f"matches={num_matches} inliers={inliers_ratio:.1%}"
    cv2.putText(canvas, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, status_color, 2)
    return canvas
def start_control_listener(port: int, on_n, on_s, on_r):
    """Start a daemon thread that maps single-byte UDP commands to callbacks.

    The thread binds ``0.0.0.0:port`` and inspects only the first byte of each
    datagram:
      * ``n`` / ``N`` / byte value 1 -> ``on_n()``
      * ``s`` / ``S``                -> ``on_s()``
      * ``r`` / ``R``                -> ``on_r()``
    Other bytes are ignored. Upper-case ``S`` and ``R`` are accepted so all
    three commands are consistently case-insensitive.

    Args:
        port: UDP port to listen on; a value <= 0 disables the listener.
        on_n, on_s, on_r: Zero-argument callables invoked from the listener
            thread.

    Returns:
        A ``threading.Event`` that stops the thread when set, or ``None`` when
        the listener is disabled.
    """
    if port <= 0:
        return None

    stop_event = threading.Event()
    handlers = {
        ord("n"): on_n,
        ord("N"): on_n,
        1: on_n,
        ord("s"): on_s,
        ord("S"): on_s,
        ord("r"): on_r,
        ord("R"): on_r,
    }

    def _worker():
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            try:
                sock.bind(("0.0.0.0", port))
            except OSError as exc:
                # A busy or forbidden port is reported once instead of killing
                # the thread with an unhandled traceback.
                print(f"[Control] failed to bind 0.0.0.0:{port}: {exc}", flush=True)
                return
            # Short timeout so the stop event is polled about twice a second.
            sock.settimeout(0.5)
            print(f"[Control] Listening on 0.0.0.0:{port} (n/s/r)", flush=True)
            while not stop_event.is_set():
                try:
                    try:
                        data, _addr = sock.recvfrom(1024)
                    except socket.timeout:
                        continue
                    if not data:
                        continue
                    handler = handlers.get(data[0])
                    if handler is not None:
                        handler()
                except OSError:
                    # Socket-level failure: give up listening.
                    break
                except Exception as exc:
                    print(f"[Control] error: {exc}", flush=True)
        finally:
            # Always release the socket, including on bind failure.
            sock.close()

    th = threading.Thread(target=_worker, daemon=True, name="RomaControlListener")
    th.start()
    return stop_event
def start_unity_ref_receiver(port: int):
    """Start a daemon thread that keeps the latest single-datagram JPEG.

    The thread binds ``0.0.0.0:port``. A datagram is kept only when it is at
    least 4 bytes long, starts with the JPEG SOI marker (``FF D8``) and ends
    with the EOI marker (``FF D9``). Only the most recent accepted datagram is
    stored; it is not cleared after being read.

    Args:
        port: UDP port to listen on; a value <= 0 disables the receiver.

    Returns:
        ``(stop_event, thread, get_latest)`` where ``get_latest()`` returns
        the most recent JPEG bytes or ``None`` if none has arrived yet.
        ``(None, None, None)`` when the receiver is disabled.
    """
    if port <= 0:
        return None, None, None

    stop_event = threading.Event()
    latest_lock = threading.Lock()
    latest_jpeg = {"data": None}

    def _worker():
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            try:
                sock.bind(("0.0.0.0", port))
            except OSError as exc:
                # A busy or forbidden port is reported once instead of killing
                # the thread with an unhandled traceback.
                print(f"[UnityRef] failed to bind 0.0.0.0:{port}: {exc}", flush=True)
                return
            # Short timeout so the stop event is polled about twice a second.
            sock.settimeout(0.5)
            print(f"[UnityRef] Listening on 0.0.0.0:{port} (JPEG single-packet)", flush=True)
            while not stop_event.is_set():
                try:
                    try:
                        data, _addr = sock.recvfrom(65535)
                    except socket.timeout:
                        continue
                    if not data or len(data) < 4:
                        continue
                    # Basic JPEG validation: SOI at the start, EOI at the end.
                    if data[:2] != b"\xFF\xD8" or data[-2:] != b"\xFF\xD9":
                        continue
                    with latest_lock:
                        latest_jpeg["data"] = data
                except OSError:
                    # Socket-level failure: give up listening.
                    break
                except Exception as exc:
                    print(f"[UnityRef] error: {exc}", flush=True)
        finally:
            # Always release the socket, including on bind failure.
            sock.close()

    th = threading.Thread(target=_worker, daemon=True, name="UnityRefReceiver")
    th.start()

    def _get_latest():
        with latest_lock:
            return latest_jpeg["data"]

    return stop_event, th, _get_latest
def parse_udp_input(input_str: str):
    """Split a ``udp://host:port`` input string into ``(host, port)``.

    Accepted forms:
      * ``udp://HOST:PORT``
      * ``udp://:PORT`` and ``udp://PORT`` (host defaults to ``0.0.0.0``)
      * ``udp://[IPV6]:PORT`` (brackets are removed from the returned host)
    A single trailing ``/`` after the port is tolerated.

    Args:
        input_str: The value of ``--input``.

    Returns:
        ``(host, port)`` with ``port`` as an ``int`` in ``0..65535``.

    Raises:
        ValueError: If the scheme is not ``udp://``, the port is missing, not
            an integer or out of range, or the host contains an unbracketed
            ``:``.
    """
    prefix = "udp://"
    if not input_str.startswith(prefix):
        raise ValueError("This bridge expects --input udp://host:port")

    # Strip only the leading scheme; str.replace would also remove any later
    # occurrence of "udp://" inside the string.
    endpoint = input_str[len(prefix):].rstrip("/")

    # rpartition yields host == "" when there is no ":" at all ("udp://PORT").
    host, _sep, port_text = endpoint.rpartition(":")

    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]
    elif ":" in host:
        raise ValueError(
            f"Invalid host in --input {input_str!r}; wrap IPv6 addresses in brackets, e.g. udp://[::1]:9000"
        )

    try:
        port = int(port_text)
    except ValueError:
        raise ValueError(
            f"Invalid UDP port in --input {input_str!r}; expected udp://host:port"
        ) from None
    if not 0 <= port <= 65535:
        raise ValueError(f"UDP port out of range in --input {input_str!r}: {port}")

    return (host or "0.0.0.0"), port
def main():
    """Run the RoMa -> Unity bridge loop.

    Per received frame the loop:
      1. applies pending control requests ("s" screen capture, "n" use the
         current frame, "r" use the Unity GameView JPEG or else the current
         frame) to replace the reference image, resetting the tracking state;
      2. forwards the receiver's latest JPEG to Unity, rate-limited by
         ``--forward_fps``;
      3. matches reference vs. current frame with RoMa, fits a RANSAC
         homography, and maps the current image centre into reference
         coordinates (exponentially smoothed with ``--smooth_alpha``);
      4. sends the result to Unity and optionally shows debug windows
         (keys: ``q`` quit, ``n`` / ``s`` as above).

    All network helpers are created inside the ``try`` so that a failure
    during start-up still releases whatever was already started.

    NOTE(review): ``--roma_interval`` is parsed but not used by this loop.
    """
    opt = parse_args()
    if opt.no_ui:
        # Silence all console output for headless launches.
        sys.stdout = open(os.devnull, "w")
        sys.stderr = open(os.devnull, "w")

    device = "cuda" if torch.cuda.is_available() and not opt.force_cpu else "cpu"
    print(f'Running inference on device "{device}"', flush=True)

    roma_model = tiny_roma_v1_outdoor(device=torch.device(device)) if opt.model == "tiny" else roma_outdoor(device=torch.device(device))
    roma_model.sample_thresh = opt.sample_thresh

    # Handles are pre-initialised so the finally block can tell what exists.
    receiver = None
    forwarder = None
    result_sender = None
    device_info_stop = None
    control_stop = None
    unity_ref_stop = None

    try:
        # UDP input receiver (FrameHeader)
        host, port = parse_udp_input(opt.input)
        receiver = UDPJPEGReceiver(host=host, port=port, timeout=0.5, max_frame_buffers=8)
        receiver.start()

        forwarder = UDPFrameForwarder(target_ip=opt.forward_ip, target_port=opt.forward_port)
        result_sender = UDPResultSender(unity_ip=opt.result_ip, unity_port=opt.result_port)

        device_info_ip = opt.device_info_ip or opt.result_ip
        device_info_stop = start_device_info_reporter(
            jpeg_receiver=receiver,
            unity_ip=device_info_ip,
            unity_port=int(opt.device_info_port),
            interval_s=float(opt.device_info_interval),
        )

        # Control events: set by the listener thread (or keyboard), consumed
        # by the main loop below.
        set_ref_now_event = threading.Event()
        screen_ref_event = threading.Event()
        next_frame_ref_event = threading.Event()

        def _on_n():
            set_ref_now_event.set()

        def _on_s():
            screen_ref_event.set()

        def _on_r():
            next_frame_ref_event.set()

        control_stop = start_control_listener(opt.control_port, _on_n, _on_s, _on_r)
        unity_ref_stop, _unity_ref_thread, get_unity_ref_jpeg = start_unity_ref_receiver(opt.unity_ref_port)

        # Reference frame (optionally preloaded from disk)
        ref_frame = None
        if opt.reference_image:
            img = cv2.imread(opt.reference_image, cv2.IMREAD_COLOR)
            if img is not None:
                ref_frame = maybe_resize(img, opt.resize)

        last_good_H = None
        last_camera_center_ref = None
        trail_points = []
        fps_ema = 0.0
        last_timer = time.time()
        last_forward = 0.0
        last_display = 0.0

        while True:
            loop_start = time.time()

            # Control: screen capture as reference
            if screen_ref_event.is_set():
                screen_ref_event.clear()
                sc = capture_screen_bgr(opt)
                if sc is not None:
                    ref_frame = sc.copy()
                    last_good_H = None
                    last_camera_center_ref = None
                    trail_points = []
                    if opt.log_control_events:
                        print("[Control] Reference updated from screen capture", flush=True)
                elif not DXCAM_AVAILABLE:
                    if opt.log_control_events:
                        print("[Control] Screen capture skipped: pip install dxcam", flush=True)

            frame = receiver.get_image(timeout=0.2)
            if frame is None:
                if opt.idle_sleep_ms and opt.idle_sleep_ms > 0:
                    time.sleep(opt.idle_sleep_ms / 1000.0)
                continue
            frame = maybe_resize(frame, opt.resize)

            # Forward the latest JPEG bytes, rate-limited by --forward_fps
            # (a value <= 0 disables the limit).
            if opt.forward_fps <= 0 or (time.time() - last_forward) >= (1.0 / max(opt.forward_fps, 1e-6)):
                jpeg = receiver.get_latest_jpeg()
                if jpeg:
                    forwarder.send_jpeg(jpeg)
                    last_forward = time.time()

            # No reference yet, or "n": adopt the current frame as reference.
            if ref_frame is None or set_ref_now_event.is_set():
                set_ref_now_event.clear()
                ref_frame = frame.copy()
                last_good_H = None
                last_camera_center_ref = None
                trail_points = []
                if opt.log_control_events:
                    print("[Control] Reference updated from current frame", flush=True)
                continue

            if next_frame_ref_event.is_set():
                next_frame_ref_event.clear()
                # Prefer the GameView JPEG sent by Unity as the reference
                # image when one is available; otherwise use the current frame.
                applied = False
                if get_unity_ref_jpeg is not None:
                    jpeg = get_unity_ref_jpeg()
                    if jpeg is not None:
                        arr = np.frombuffer(jpeg, dtype=np.uint8)
                        img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
                        if img is not None:
                            ref_frame = maybe_resize(img, opt.resize)
                            applied = True
                if not applied:
                    ref_frame = frame.copy()
                last_good_H = None
                last_camera_center_ref = None
                trail_points = []
                # Gated like the other reference-update logs.
                if opt.log_control_events:
                    print("[Control] Reference updated from Unity GameView JPEG (r)" if applied else "[Control] Reference updated from next frame (r)", flush=True)
                continue

            # RoMa estimate H_ref->cur
            h_ref, w_ref = ref_frame.shape[:2]
            h_cur, w_cur = frame.shape[:2]
            num_matches = 0
            inliers_ratio = 0.0
            H_ref_to_cur = None

            ref_pil = Image.fromarray(cv2.cvtColor(ref_frame, cv2.COLOR_BGR2RGB))
            cur_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            warp, certainty = roma_model.match(ref_pil, cur_pil) if opt.model == "tiny" else roma_model.match(ref_pil, cur_pil, device=torch.device(device))
            matches, _ = roma_model.sample(warp, certainty, num=opt.sample_num)
            k_ref, k_cur = roma_model.to_pixel_coordinates(matches, h_ref, w_ref, h_cur, w_cur)
            pts_ref = k_ref.detach().cpu().numpy().astype(np.float32)
            pts_cur = k_cur.detach().cpu().numpy().astype(np.float32)
            num_matches = len(pts_ref)

            if num_matches >= opt.min_matches:
                H_tmp, mask = cv2.findHomography(pts_ref, pts_cur, cv2.RANSAC, opt.ransac_reproj_threshold)
                if H_tmp is not None and mask is not None:
                    inliers_ratio = float(mask.mean())
                    if inliers_ratio >= opt.min_inlier_ratio:
                        H_ref_to_cur = H_tmp
                        last_good_H = H_tmp
            if H_ref_to_cur is None:
                # Fall back to the last accepted homography, if any.
                H_ref_to_cur = last_good_H

            # Camera centre mapping: current image centre -> reference coords.
            camera_center_current = (w_cur // 2, h_cur // 2)
            camera_center_ref = None
            is_valid = False
            if H_ref_to_cur is not None:
                try:
                    H_cur_to_ref = np.linalg.inv(H_ref_to_cur)
                    center_cur = np.array([[camera_center_current]], dtype=np.float32)
                    center_ref_now = cv2.perspectiveTransform(center_cur, H_cur_to_ref)[0, 0]
                    center_ref_now = np.clip(center_ref_now, [0, 0], [w_ref - 1, h_ref - 1])
                    if last_camera_center_ref is None:
                        camera_center_ref = center_ref_now
                    else:
                        camera_center_ref = opt.smooth_alpha * last_camera_center_ref + (1.0 - opt.smooth_alpha) * center_ref_now
                    last_camera_center_ref = camera_center_ref
                    is_valid = True
                    trail_points.append(camera_center_ref.copy())
                    # Drop the oldest points beyond --trail_len. Computing the
                    # overflow explicitly also handles trail_len == 0, where
                    # the slice [-0:] would keep the whole list forever.
                    overflow = len(trail_points) - max(opt.trail_len, 0)
                    if overflow > 0:
                        del trail_points[:overflow]
                except np.linalg.LinAlgError:
                    is_valid = False

            # Send result
            if is_valid and camera_center_ref is not None:
                result_sender.send_result(True, num_matches, inliers_ratio, float(camera_center_ref[0]), float(camera_center_ref[1]))
            else:
                result_sender.send_result(False, num_matches, 0.0, 0.0, 0.0)

            # Display (rate-limited by --max_display_fps)
            if not opt.no_display:
                nowp = time.time()
                if opt.max_display_fps <= 0 or (nowp - last_display) >= (1.0 / max(opt.max_display_fps, 1e-6)):
                    ref_view = draw_reference_view(ref_frame, camera_center_ref, is_valid, num_matches, inliers_ratio, trail_points)
                    if ref_view is not None and opt.show_fps:
                        cv2.putText(ref_view, f"FPS: {fps_ema:.1f}", (10, 60),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                    if ref_view is not None:
                        cv2.imshow("Camera Position in Reference", ref_view)
                    cv2.imshow("Live Camera", frame)
                    key = cv2.waitKey(1) & 0xFF
                    if key == ord("q"):
                        break
                    if key == ord("n"):
                        set_ref_now_event.set()
                    if key == ord("s"):
                        screen_ref_event.set()
                    last_display = nowp

            # FPS & throttle
            if opt.max_fps and opt.max_fps > 0:
                elapsed = time.time() - loop_start
                target_dt = 1.0 / float(opt.max_fps)
                if elapsed < target_dt:
                    time.sleep(target_dt - elapsed)
            elapsed = time.time() - loop_start
            fps_ema = 0.9 * fps_ema + 0.1 * (1.0 / max(elapsed, 1e-6))

            if opt.timer_print_interval and opt.timer_print_interval > 0 and (time.time() - last_timer) >= opt.timer_print_interval:
                if opt.log_timer:
                    print(f"[RomaBridge] FPS={fps_ema:.1f} matches={num_matches} inliers={inliers_ratio:.1%}", flush=True)
                if opt.log_send_result:
                    if is_valid and camera_center_ref is not None:
                        print(f"[RomaBridge] send_result valid=1 x={float(camera_center_ref[0]):.2f} y={float(camera_center_ref[1]):.2f} matches={num_matches} inliers={inliers_ratio:.1%}", flush=True)
                    else:
                        print(f"[RomaBridge] send_result valid=0 matches={num_matches}", flush=True)
                last_timer = time.time()
    finally:
        # Best-effort teardown: one failing component must not prevent the
        # others from being released.
        if receiver is not None:
            try:
                receiver.stop()
            except Exception:
                pass
        if forwarder is not None:
            try:
                forwarder.close()
            except Exception:
                pass
        if result_sender is not None:
            try:
                result_sender.close()
            except Exception:
                pass
        for stop_flag in (device_info_stop, control_stop, unity_ref_stop):
            if stop_flag is not None:
                stop_flag.set()
        try:
            cv2.destroyAllWindows()
        except Exception:
            pass
- if __name__ == "__main__":
- main()
|