image-match.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模板匹配:在截图中查找模板图片的位置
  5. 用法1: python image-match.py <screenshot_path> <template_path> [threshold]
  6. 用法2: python image-match.py --adb <adb_path> --device <device_id> --screenshot <out_path> --template <template_path> [--threshold 0.8] [--method template|feature]
  7. 用法2 会在 Python 内执行 adb 截图,避免 Node 处理二进制数据导致的兼容性问题
  8. --method feature: 特征点匹配(优先 RoMa,失败则 ORB + 多尺度模板),不同分辨率可复用
  9. --method template: 像素模板匹配(TM_CCOEFF_NORMED),仅适合同分辨率
  10. 输出: JSON 到 stdout
  11. """
  12. import sys
  13. import os
  14. import json
  15. import subprocess
try:
    import cv2
    import numpy as np
except ImportError as e:
    # OpenCV and numpy are mandatory: report the failure in the same JSON format the
    # caller parses on stdout, then exit with status 1.
    print(json.dumps({"success": False, "error": f"OpenCV 导入失败: {e}。请安装: pip install opencv-python numpy"}))
    sys.exit(1)
# Pillow is optional: load_image uses it as a last-resort decoder and match_by_roma
# uses it to write temporary PNGs; HAS_PIL records whether it is available.
try:
    from PIL import Image as PILImage
    HAS_PIL = True
except ImportError:
    HAS_PIL = False
# RoMa (optional): if installed (python/RoMa, `pip install -e .`), it is preferred for
# feature matching. The checkout at <script dir>/../../python/RoMa is prepended to
# sys.path when that directory exists, so `romatch` can be imported from it.
HAS_ROMA = False
try:
    _roma_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'python', 'RoMa'))
    if os.path.isdir(_roma_root) and _roma_root not in sys.path:
        sys.path.insert(0, _roma_root)
    from romatch import roma_outdoor
    import torch as _torch_roma
    HAS_ROMA = True
except Exception:
    # Best effort: any failure while importing romatch/torch simply leaves
    # HAS_ROMA False, disabling the RoMa matcher.
    pass
  38. def run_adb_screencap(adb_path, device, output_path):
  39. """在 Python 内执行 adb 截图,直接处理二进制流"""
  40. # Windows 下子进程需要可执行路径,正斜杠也可用
  41. args = [adb_path.replace('/', os.sep), '-s', device, 'exec-out', 'screencap', '-p']
  42. try:
  43. result = subprocess.run(args, capture_output=True, timeout=15)
  44. if result.returncode != 0:
  45. return False, (result.stderr or result.stdout or b'').decode('utf-8', errors='replace')
  46. data = result.stdout
  47. if not data or len(data) < 100:
  48. return False, "截图数据为空"
  49. # 注意:不要对 PNG 数据做 \r\n 替换,会破坏 IDAT 压缩块导致无法解析
  50. out_dir = os.path.dirname(output_path)
  51. if out_dir:
  52. os.makedirs(out_dir, exist_ok=True)
  53. with open(output_path, 'wb') as f:
  54. f.write(data)
  55. return True, output_path
  56. except subprocess.TimeoutExpired:
  57. return False, "截图超时"
  58. except Exception as e:
  59. return False, str(e)
  60. def load_image(path):
  61. """从文件路径加载图片,兼容 OpenCV 无法直接读取的 PNG(如部分 Android 截图)"""
  62. if not os.path.exists(path):
  63. return None
  64. with open(path, 'rb') as f:
  65. data = np.frombuffer(f.read(), dtype=np.uint8)
  66. img = cv2.imdecode(data, cv2.IMREAD_COLOR)
  67. if img is not None:
  68. return img
  69. img = cv2.imread(path)
  70. if img is not None:
  71. return img
  72. if HAS_PIL:
  73. try:
  74. pil_img = PILImage.open(path).convert('RGB')
  75. img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
  76. return img
  77. except Exception:
  78. pass
  79. return None
  80. def _roma_params():
  81. """从环境变量读取 RoMa 参数,便于反复测试调参。默认针对「模板为截图中缩略图」优化。"""
  82. import os as _os
  83. coarse = int(_os.environ.get("ROMA_COARSE_RES", "560"))
  84. upsample = int(_os.environ.get("ROMA_UPSAMPLE_RES", "1152"))
  85. min_m = int(_os.environ.get("ROMA_MIN_MATCHES", "3"))
  86. sample_num = int(_os.environ.get("ROMA_SAMPLE_NUM", "20000"))
  87. ransac = float(_os.environ.get("ROMA_RANSAC_THRESH", "14.0"))
  88. return coarse, upsample, min_m, sample_num, ransac
  89. def match_by_roma(screenshot, template, min_matches=6, device=None):
  90. """
  91. 使用 RoMa 稠密特征匹配,在截图中找模板位置;精度高、跨分辨率。
  92. 返回 (x, y, w, h, center_x, center_y) 或 None。
  93. 可通过环境变量调参: ROMA_COARSE_RES, ROMA_UPSAMPLE_RES, ROMA_MIN_MATCHES, ROMA_SAMPLE_NUM, ROMA_RANSAC_THRESH
  94. """
  95. if not HAS_ROMA:
  96. return None
  97. t_h, t_w = template.shape[:2]
  98. sh_h, sh_w = screenshot.shape[:2]
  99. coarse_res, upsample_res, env_min_matches, sample_num, ransac_thresh = _roma_params()
  100. min_matches = env_min_matches # 调参时用环境变量 ROMA_MIN_MATCHES
  101. import tempfile
  102. try:
  103. if _torch_roma.get_float32_matmul_precision() != "highest":
  104. _torch_roma.set_float32_matmul_precision("highest")
  105. except Exception:
  106. pass
  107. try:
  108. if device is None:
  109. device = _torch_roma.device("cuda" if _torch_roma.cuda.is_available() else "cpu")
  110. roma_model = roma_outdoor(device=device, coarse_res=coarse_res, upsample_res=upsample_res)
  111. with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fa:
  112. with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fb:
  113. path_a = fa.name
  114. path_b = fb.name
  115. try:
  116. if HAS_PIL:
  117. PILImage.fromarray(cv2.cvtColor(screenshot, cv2.COLOR_BGR2RGB)).save(path_a)
  118. PILImage.fromarray(cv2.cvtColor(template, cv2.COLOR_BGR2RGB)).save(path_b)
  119. else:
  120. cv2.imwrite(path_a, cv2.cvtColor(screenshot, cv2.COLOR_BGR2RGB))
  121. cv2.imwrite(path_b, cv2.cvtColor(template, cv2.COLOR_BGR2RGB))
  122. warp, certainty = roma_model.match(path_a, path_b, device=device)
  123. matches, certainty = roma_model.sample(warp, certainty, num=sample_num)
  124. H_out, W_out = roma_model.get_output_resolution()
  125. kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_out, W_out, H_out, W_out)
  126. kptsA = kptsA.cpu().numpy().astype(np.float32)
  127. kptsB = kptsB.cpu().numpy().astype(np.float32)
  128. if kptsA.shape[0] < min_matches:
  129. return None
  130. scale_ax = sh_w / float(W_out)
  131. scale_ay = sh_h / float(H_out)
  132. scale_bx = t_w / float(W_out)
  133. scale_by = t_h / float(H_out)
  134. kptsA_orig = kptsA * np.array([scale_ax, scale_ay])
  135. kptsB_orig = kptsB * np.array([scale_bx, scale_by])
  136. # RANSAC 距离阈值略放宽,适配缩放/透视变形(可由 ROMA_RANSAC_THRESH 调节)
  137. H, mask = cv2.findHomography(kptsB_orig, kptsA_orig, cv2.RANSAC, ransac_thresh)
  138. if H is None:
  139. return None
  140. corners = np.float32([[0, 0], [t_w, 0], [t_w, t_h], [0, t_h]]).reshape(-1, 1, 2)
  141. corners_screen = cv2.perspectiveTransform(corners, H)
  142. x_coords = corners_screen[:, 0, 0]
  143. y_coords = corners_screen[:, 0, 1]
  144. x = int(round(np.min(x_coords)))
  145. y = int(round(np.min(y_coords)))
  146. w = int(round(np.max(x_coords) - np.min(x_coords)))
  147. h = int(round(np.max(y_coords) - np.min(y_coords)))
  148. center_x = int(round(np.mean(x_coords)))
  149. center_y = int(round(np.mean(y_coords)))
  150. return (x, y, w, h, center_x, center_y)
  151. finally:
  152. try:
  153. os.unlink(path_a)
  154. os.unlink(path_b)
  155. except Exception:
  156. pass
  157. except Exception:
  158. return None
  159. def match_by_features(screenshot, template, min_good_matches=6):
  160. """
  161. 基于特征点(ORB)匹配作为回退:在截图中找模板位置,返回 (x, y, w, h, center_x, center_y) 或 None。
  162. """
  163. gray_screen = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
  164. gray_tpl = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
  165. t_h, t_w = template.shape[:2]
  166. orb = cv2.ORB_create(nfeatures=2000)
  167. kp1, desc1 = orb.detectAndCompute(gray_tpl, None)
  168. kp2, desc2 = orb.detectAndCompute(gray_screen, None)
  169. if desc1 is None or desc2 is None or len(kp1) < 4 or len(kp2) < 4:
  170. return None
  171. bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
  172. matches = bf.knnMatch(desc1, desc2, k=2)
  173. good = []
  174. for m_n in matches:
  175. if len(m_n) != 2:
  176. continue
  177. m, n = m_n
  178. if m.distance < 0.82 * n.distance:
  179. good.append(m)
  180. if len(good) < min_good_matches:
  181. return None
  182. src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
  183. dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
  184. H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
  185. if H is None:
  186. return None
  187. # 模板四角在截图中的坐标,用质心作为中心点
  188. corners = np.float32([[0, 0], [t_w, 0], [t_w, t_h], [0, t_h]]).reshape(-1, 1, 2)
  189. corners_screen = cv2.perspectiveTransform(corners, H)
  190. x_coords = corners_screen[:, 0, 0]
  191. y_coords = corners_screen[:, 0, 1]
  192. x = int(round(np.min(x_coords)))
  193. y = int(round(np.min(y_coords)))
  194. w = int(round(np.max(x_coords) - np.min(x_coords)))
  195. h = int(round(np.max(y_coords) - np.min(y_coords)))
  196. center_x = int(round(np.mean(x_coords)))
  197. center_y = int(round(np.mean(y_coords)))
  198. return (x, y, w, h, center_x, center_y)
  199. def multi_scale_template_match(screenshot, template, threshold=0.50, scale_min=0.4, scale_max=1.65):
  200. """
  201. 多尺度模板匹配:对模板做多种缩放后在截图中匹配,适配不同分辨率(如简单图标、轮廓)。
  202. scale_min, scale_max: 缩放比范围,如 0.08~2.0 可匹配截图中小缩略图。
  203. 返回 (x, y, w, h, center_x, center_y) 或 None。
  204. """
  205. gray_screen = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
  206. gray_tpl = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
  207. sh, sw = screenshot.shape[:2]
  208. t_h, t_w = template.shape[:2]
  209. best = None
  210. best_val = threshold
  211. step = max(0.03, (scale_max - scale_min) / 38.0)
  212. for scale in np.arange(scale_min, scale_max + step * 0.5, step):
  213. w = max(8, int(round(t_w * scale)))
  214. h = max(8, int(round(t_h * scale)))
  215. if h > sh or w > sw:
  216. continue
  217. resized = cv2.resize(gray_tpl, (w, h), interpolation=cv2.INTER_AREA)
  218. result = cv2.matchTemplate(gray_screen, resized, cv2.TM_CCOEFF_NORMED)
  219. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  220. if max_val > best_val:
  221. best_val = max_val
  222. x, y = int(max_loc[0]), int(max_loc[1])
  223. center_x = x + w // 2
  224. center_y = y + h // 2
  225. best = (x, y, w, h, center_x, center_y)
  226. return best
  227. def main():
  228. screenshot_path = None
  229. template_path = None
  230. threshold = 0.8
  231. method = 'feature' # feature=特征点匹配(跨分辨率), template=像素模板匹配
  232. adb_path = None
  233. device = None
  234. scale_min, scale_max = 0.4, 1.65
  235. if len(sys.argv) >= 2 and sys.argv[1] == '--adb':
  236. # 用法2:--adb --device --screenshot --template [--scale-min 0.2] [--scale-max 1.6]
  237. i = 1
  238. while i < len(sys.argv):
  239. if sys.argv[i] == '--adb' and i + 1 < len(sys.argv):
  240. adb_path = sys.argv[i + 1]
  241. i += 2
  242. elif sys.argv[i] == '--device' and i + 1 < len(sys.argv):
  243. device = sys.argv[i + 1]
  244. i += 2
  245. elif sys.argv[i] == '--screenshot' and i + 1 < len(sys.argv):
  246. screenshot_path = sys.argv[i + 1]
  247. i += 2
  248. elif sys.argv[i] == '--template' and i + 1 < len(sys.argv):
  249. template_path = sys.argv[i + 1]
  250. i += 2
  251. elif sys.argv[i] == '--threshold' and i + 1 < len(sys.argv):
  252. threshold = float(sys.argv[i + 1])
  253. i += 2
  254. elif sys.argv[i] == '--method' and i + 1 < len(sys.argv):
  255. method = (sys.argv[i + 1] or 'feature').strip().lower()
  256. if method not in ('template', 'feature'):
  257. method = 'feature'
  258. i += 2
  259. elif sys.argv[i] == '--scale-min' and i + 1 < len(sys.argv):
  260. scale_min = float(sys.argv[i + 1])
  261. i += 2
  262. elif sys.argv[i] == '--scale-max' and i + 1 < len(sys.argv):
  263. scale_max = float(sys.argv[i + 1])
  264. i += 2
  265. else:
  266. i += 1
  267. if adb_path and device and screenshot_path and template_path:
  268. ok, msg = run_adb_screencap(adb_path, device, screenshot_path)
  269. if not ok:
  270. print(json.dumps({"success": False, "error": f"截图失败: {msg}"}))
  271. sys.exit(1)
  272. else:
  273. print(json.dumps({"success": False, "error": "缺少 --adb/--device/--screenshot/--template 参数"}))
  274. sys.exit(1)
  275. else:
  276. # 用法1:位置参数
  277. if len(sys.argv) < 3:
  278. print(json.dumps({"success": False, "error": "用法: image-match.py <screenshot_path> <template_path> [threshold] [method=feature|template]"}))
  279. sys.exit(1)
  280. screenshot_path = sys.argv[1]
  281. template_path = sys.argv[2]
  282. threshold = float(sys.argv[3]) if len(sys.argv) > 3 else 0.8
  283. if len(sys.argv) > 4 and sys.argv[4].lower() in ('template', 'feature'):
  284. method = sys.argv[4].lower()
  285. if not os.path.exists(screenshot_path):
  286. print(json.dumps({"success": False, "error": f"截图文件不存在: {screenshot_path}"}))
  287. sys.exit(1)
  288. if not os.path.exists(template_path):
  289. print(json.dumps({"success": False, "error": f"模板文件不存在: {template_path}"}))
  290. sys.exit(1)
  291. screenshot = load_image(screenshot_path)
  292. template = load_image(template_path)
  293. if screenshot is None:
  294. print(json.dumps({"success": False, "error": "无法读取截图(文件损坏或格式不支持)"}))
  295. sys.exit(1)
  296. if template is None:
  297. print(json.dumps({"success": False, "error": f"无法读取模板: {template_path}"}))
  298. sys.exit(1)
  299. t_h, t_w = template.shape[:2]
  300. if method == 'template' and (t_h > screenshot.shape[0] or t_w > screenshot.shape[1]):
  301. print(json.dumps({"success": False, "error": "模板尺寸大于截图"}))
  302. sys.exit(1)
  303. if method == 'feature':
  304. # 1) RoMa 稠密特征匹配(若已安装);失败时用备用参数再试一次
  305. if HAS_ROMA:
  306. roma_result = match_by_roma(screenshot, template, min_matches=4)
  307. if roma_result is None:
  308. _save = (os.environ.get('ROMA_COARSE_RES'), os.environ.get('ROMA_UPSAMPLE_RES'), os.environ.get('ROMA_MIN_MATCHES'))
  309. for co, up, mn in [(672, 1120, 4), (448, 864, 2)]:
  310. try:
  311. os.environ['ROMA_COARSE_RES'] = str(co)
  312. os.environ['ROMA_UPSAMPLE_RES'] = str(up)
  313. os.environ['ROMA_MIN_MATCHES'] = str(mn)
  314. roma_result = match_by_roma(screenshot, template, min_matches=mn)
  315. if roma_result is not None:
  316. break
  317. finally:
  318. pass
  319. try:
  320. if _save[0] is None and 'ROMA_COARSE_RES' in os.environ:
  321. del os.environ['ROMA_COARSE_RES']
  322. elif _save[0] is not None:
  323. os.environ['ROMA_COARSE_RES'] = _save[0]
  324. if _save[1] is None and 'ROMA_UPSAMPLE_RES' in os.environ:
  325. del os.environ['ROMA_UPSAMPLE_RES']
  326. elif _save[1] is not None:
  327. os.environ['ROMA_UPSAMPLE_RES'] = _save[1]
  328. if _save[2] is None and 'ROMA_MIN_MATCHES' in os.environ:
  329. del os.environ['ROMA_MIN_MATCHES']
  330. elif _save[2] is not None:
  331. os.environ['ROMA_MIN_MATCHES'] = _save[2]
  332. except Exception:
  333. pass
  334. if roma_result is not None:
  335. x, y, w, h, center_x, center_y = roma_result
  336. output = {
  337. "success": True,
  338. "x": x,
  339. "y": y,
  340. "width": w,
  341. "height": h,
  342. "center_x": center_x,
  343. "center_y": center_y
  344. }
  345. print(json.dumps(output))
  346. sys.exit(0)
  347. # 2) 回退:ORB 特征点匹配
  348. feat_result = match_by_features(screenshot, template)
  349. if feat_result is not None:
  350. x, y, w, h, center_x, center_y = feat_result
  351. output = {
  352. "success": True,
  353. "x": x,
  354. "y": y,
  355. "width": w,
  356. "height": h,
  357. "center_x": center_x,
  358. "center_y": center_y
  359. }
  360. print(json.dumps(output))
  361. sys.exit(0)
  362. # 3) 回退:多尺度模板匹配,放宽阈值与步数以适配截图中缩略图
  363. fallback_threshold = min(threshold, 0.50)
  364. scale_min_use = min(scale_min, 0.08)
  365. scale_result = multi_scale_template_match(screenshot, template, threshold=fallback_threshold, scale_min=scale_min_use, scale_max=scale_max)
  366. if scale_result is None and (t_w > 1.3 * t_h or t_h > 1.3 * t_w):
  367. t_s = min(t_w, t_h)
  368. cx, cy = t_w // 2, t_h // 2
  369. y0, y1 = max(0, cy - t_s // 2), min(t_h, cy + t_s // 2)
  370. x0, x1 = max(0, cx - t_s // 2), min(t_w, cx + t_s // 2)
  371. if y1 > y0 and x1 > x0:
  372. crop = template[y0:y1, x0:x1]
  373. scale_result = multi_scale_template_match(screenshot, crop, threshold=fallback_threshold, scale_min=scale_min_use, scale_max=scale_max)
  374. if scale_result is not None:
  375. x, y, w, h, center_x, center_y = scale_result
  376. output = {
  377. "success": True,
  378. "x": x,
  379. "y": y,
  380. "width": w,
  381. "height": h,
  382. "center_x": center_x,
  383. "center_y": center_y
  384. }
  385. print(json.dumps(output))
  386. sys.exit(0)
  387. print(json.dumps({"success": False, "error": "RoMa/特征点与多尺度模板均未匹配(可检查模板是否在画面中或使用 --method template)"}))
  388. sys.exit(1)
  389. # 使用 TM_CCOEFF_NORMED 进行模板匹配(仅同分辨率推荐)
  390. result = cv2.matchTemplate(screenshot, template, cv2.TM_CCOEFF_NORMED)
  391. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  392. if max_val < threshold:
  393. print(json.dumps({"success": False, "error": f"未找到匹配 (相似度 {max_val:.3f} < {threshold})"}))
  394. sys.exit(1)
  395. x, y = int(max_loc[0]), int(max_loc[1])
  396. center_x = x + t_w // 2
  397. center_y = y + t_h // 2
  398. output = {
  399. "success": True,
  400. "x": x,
  401. "y": y,
  402. "width": t_w,
  403. "height": t_h,
  404. "center_x": center_x,
  405. "center_y": center_y
  406. }
  407. print(json.dumps(output))
  408. sys.exit(0)
if __name__ == "__main__":
    # Script entry point: main() prints a JSON result to stdout and ends via sys.exit.
    main()