image-match.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 模板匹配:在截图中查找模板图片的位置
  5. 用法1: python image-match.py <screenshot_path> <template_path> [threshold]
  6. 用法2: python image-match.py --adb <adb_path> --device <device_id> --screenshot <out_path> --template <template_path> [--threshold 0.8] [--method template|feature]
  7. 用法2 会在 Python 内执行 adb 截图,避免 Node 处理二进制数据导致的兼容性问题
  8. --method feature: 特征点匹配(优先 LightGlue,失败则 ORB + 多尺度模板),不同分辨率可复用
  9. --method template: 像素模板匹配(TM_CCOEFF_NORMED),仅适合同分辨率
  10. 输出: JSON 到 stdout
  11. """
  12. import sys
  13. import os
  14. import json
  15. import subprocess
  16. try:
  17. import cv2
  18. import numpy as np
  19. except ImportError as e:
  20. print(json.dumps({"success": False, "error": f"OpenCV 导入失败: {e}。请安装: pip install opencv-python numpy"}))
  21. sys.exit(1)
  22. try:
  23. from PIL import Image as PILImage
  24. HAS_PIL = True
  25. except ImportError:
  26. HAS_PIL = False
  27. # LightGlue:若已安装(python/LightGlue pip install -e .),则优先用于 feature 匹配
  28. HAS_LIGHTGLUE = False
  29. try:
  30. _lg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'LightGlue'))
  31. if _lg_root not in sys.path:
  32. sys.path.insert(0, _lg_root)
  33. from lightglue import LightGlue, SuperPoint
  34. from lightglue.utils import match_pair
  35. import torch
  36. HAS_LIGHTGLUE = True
  37. except Exception:
  38. pass
  39. def run_adb_screencap(adb_path, device, output_path):
  40. """在 Python 内执行 adb 截图,直接处理二进制流"""
  41. # Windows 下子进程需要可执行路径,正斜杠也可用
  42. args = [adb_path.replace('/', os.sep), '-s', device, 'exec-out', 'screencap', '-p']
  43. try:
  44. result = subprocess.run(args, capture_output=True, timeout=15)
  45. if result.returncode != 0:
  46. return False, (result.stderr or result.stdout or b'').decode('utf-8', errors='replace')
  47. data = result.stdout
  48. if not data or len(data) < 100:
  49. return False, "截图数据为空"
  50. # 注意:不要对 PNG 数据做 \r\n 替换,会破坏 IDAT 压缩块导致无法解析
  51. out_dir = os.path.dirname(output_path)
  52. if out_dir:
  53. os.makedirs(out_dir, exist_ok=True)
  54. with open(output_path, 'wb') as f:
  55. f.write(data)
  56. return True, output_path
  57. except subprocess.TimeoutExpired:
  58. return False, "截图超时"
  59. except Exception as e:
  60. return False, str(e)
  61. def load_image(path):
  62. """从文件路径加载图片,兼容 OpenCV 无法直接读取的 PNG(如部分 Android 截图)"""
  63. if not os.path.exists(path):
  64. return None
  65. with open(path, 'rb') as f:
  66. data = np.frombuffer(f.read(), dtype=np.uint8)
  67. img = cv2.imdecode(data, cv2.IMREAD_COLOR)
  68. if img is not None:
  69. return img
  70. img = cv2.imread(path)
  71. if img is not None:
  72. return img
  73. if HAS_PIL:
  74. try:
  75. pil_img = PILImage.open(path).convert('RGB')
  76. img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
  77. return img
  78. except Exception:
  79. pass
  80. return None
  81. def _numpy_bgr_to_torch_rgb(img_bgr):
  82. """(H,W,3) BGR numpy uint8 -> (3,H,W) float [0,1] RGB for LightGlue"""
  83. rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  84. t = np.ascontiguousarray(rgb.transpose(2, 0, 1))
  85. return torch.from_numpy(t).float().div(255.0)
  86. def match_by_lightglue(screenshot, template, min_matches=6, device='cpu'):
  87. """
  88. 使用 LightGlue + SuperPoint 做特征匹配,在截图中找模板位置。
  89. 参数已调优以提高难图(如缩略图)匹配率:更多特征点、关闭裁剪/早停、放宽匹配与 RANSAC。
  90. 返回 (x, y, w, h, center_x, center_y) 或 None。
  91. """
  92. if not HAS_LIGHTGLUE:
  93. return None
  94. t_h, t_w = template.shape[:2]
  95. try:
  96. img0 = _numpy_bgr_to_torch_rgb(screenshot)
  97. img1 = _numpy_bgr_to_torch_rgb(template)
  98. extractor = SuperPoint(max_num_keypoints=4096).eval().to(device)
  99. matcher = LightGlue(
  100. features='superpoint',
  101. depth_confidence=-1,
  102. width_confidence=-1,
  103. filter_threshold=0.05,
  104. ).eval().to(device)
  105. feats0, feats1, matches01 = match_pair(extractor, matcher, img0, img1, device=device)
  106. matches = matches01.get('matches')
  107. if matches is None or matches.shape[0] < min_matches:
  108. return None
  109. kp0 = feats0['keypoints']
  110. kp1 = feats1['keypoints']
  111. idx0 = matches[:, 0]
  112. idx1 = matches[:, 1]
  113. pts_screen = kp0[idx0].cpu().numpy().astype(np.float32)
  114. pts_template = kp1[idx1].cpu().numpy().astype(np.float32)
  115. H, mask = cv2.findHomography(pts_template, pts_screen, cv2.RANSAC, 8.0)
  116. if H is None:
  117. return None
  118. corners = np.float32([[0, 0], [t_w, 0], [t_w, t_h], [0, t_h]]).reshape(-1, 1, 2)
  119. corners_screen = cv2.perspectiveTransform(corners, H)
  120. x_coords = corners_screen[:, 0, 0]
  121. y_coords = corners_screen[:, 0, 1]
  122. x = int(round(np.min(x_coords)))
  123. y = int(round(np.min(y_coords)))
  124. w = int(round(np.max(x_coords) - np.min(x_coords)))
  125. h = int(round(np.max(y_coords) - np.min(y_coords)))
  126. center_x = int(round(np.mean(x_coords)))
  127. center_y = int(round(np.mean(y_coords)))
  128. return (x, y, w, h, center_x, center_y)
  129. except Exception:
  130. return None
  131. def match_by_features(screenshot, template, min_good_matches=8):
  132. """
  133. 基于特征点(ORB)匹配作为回退:在截图中找模板位置,返回 (x, y, w, h, center_x, center_y) 或 None。
  134. """
  135. gray_screen = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
  136. gray_tpl = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
  137. t_h, t_w = template.shape[:2]
  138. orb = cv2.ORB_create(nfeatures=2000)
  139. kp1, desc1 = orb.detectAndCompute(gray_tpl, None)
  140. kp2, desc2 = orb.detectAndCompute(gray_screen, None)
  141. if desc1 is None or desc2 is None or len(kp1) < 4 or len(kp2) < 4:
  142. return None
  143. bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
  144. matches = bf.knnMatch(desc1, desc2, k=2)
  145. good = []
  146. for m_n in matches:
  147. if len(m_n) != 2:
  148. continue
  149. m, n = m_n
  150. if m.distance < 0.75 * n.distance:
  151. good.append(m)
  152. if len(good) < min_good_matches:
  153. return None
  154. src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
  155. dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
  156. H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
  157. if H is None:
  158. return None
  159. # 模板四角在截图中的坐标,用质心作为中心点
  160. corners = np.float32([[0, 0], [t_w, 0], [t_w, t_h], [0, t_h]]).reshape(-1, 1, 2)
  161. corners_screen = cv2.perspectiveTransform(corners, H)
  162. x_coords = corners_screen[:, 0, 0]
  163. y_coords = corners_screen[:, 0, 1]
  164. x = int(round(np.min(x_coords)))
  165. y = int(round(np.min(y_coords)))
  166. w = int(round(np.max(x_coords) - np.min(x_coords)))
  167. h = int(round(np.max(y_coords) - np.min(y_coords)))
  168. center_x = int(round(np.mean(x_coords)))
  169. center_y = int(round(np.mean(y_coords)))
  170. return (x, y, w, h, center_x, center_y)
  171. def multi_scale_template_match(screenshot, template, threshold=0.65, scale_min=0.4, scale_max=1.65):
  172. """
  173. 多尺度模板匹配:对模板做多种缩放后在截图中匹配,适配不同分辨率(如简单图标、轮廓)。
  174. scale_min, scale_max: 缩放比范围,如 0.2~1.6 可匹配缩略图。
  175. 返回 (x, y, w, h, center_x, center_y) 或 None。
  176. """
  177. gray_screen = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
  178. gray_tpl = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
  179. sh, sw = screenshot.shape[:2]
  180. t_h, t_w = template.shape[:2]
  181. best = None
  182. best_val = threshold
  183. step = max(0.05, (scale_max - scale_min) / 12.0)
  184. for scale in np.arange(scale_min, scale_max + step * 0.5, step):
  185. w = max(8, int(round(t_w * scale)))
  186. h = max(8, int(round(t_h * scale)))
  187. if h > sh or w > sw:
  188. continue
  189. resized = cv2.resize(gray_tpl, (w, h), interpolation=cv2.INTER_AREA)
  190. result = cv2.matchTemplate(gray_screen, resized, cv2.TM_CCOEFF_NORMED)
  191. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  192. if max_val > best_val:
  193. best_val = max_val
  194. x, y = int(max_loc[0]), int(max_loc[1])
  195. center_x = x + w // 2
  196. center_y = y + h // 2
  197. best = (x, y, w, h, center_x, center_y)
  198. return best
  199. def main():
  200. screenshot_path = None
  201. template_path = None
  202. threshold = 0.8
  203. method = 'feature' # feature=特征点匹配(跨分辨率), template=像素模板匹配
  204. adb_path = None
  205. device = None
  206. scale_min, scale_max = 0.4, 1.65
  207. if len(sys.argv) >= 2 and sys.argv[1] == '--adb':
  208. # 用法2:--adb --device --screenshot --template [--scale-min 0.2] [--scale-max 1.6]
  209. i = 1
  210. while i < len(sys.argv):
  211. if sys.argv[i] == '--adb' and i + 1 < len(sys.argv):
  212. adb_path = sys.argv[i + 1]
  213. i += 2
  214. elif sys.argv[i] == '--device' and i + 1 < len(sys.argv):
  215. device = sys.argv[i + 1]
  216. i += 2
  217. elif sys.argv[i] == '--screenshot' and i + 1 < len(sys.argv):
  218. screenshot_path = sys.argv[i + 1]
  219. i += 2
  220. elif sys.argv[i] == '--template' and i + 1 < len(sys.argv):
  221. template_path = sys.argv[i + 1]
  222. i += 2
  223. elif sys.argv[i] == '--threshold' and i + 1 < len(sys.argv):
  224. threshold = float(sys.argv[i + 1])
  225. i += 2
  226. elif sys.argv[i] == '--method' and i + 1 < len(sys.argv):
  227. method = (sys.argv[i + 1] or 'feature').strip().lower()
  228. if method not in ('template', 'feature'):
  229. method = 'feature'
  230. i += 2
  231. elif sys.argv[i] == '--scale-min' and i + 1 < len(sys.argv):
  232. scale_min = float(sys.argv[i + 1])
  233. i += 2
  234. elif sys.argv[i] == '--scale-max' and i + 1 < len(sys.argv):
  235. scale_max = float(sys.argv[i + 1])
  236. i += 2
  237. else:
  238. i += 1
  239. if adb_path and device and screenshot_path and template_path:
  240. ok, msg = run_adb_screencap(adb_path, device, screenshot_path)
  241. if not ok:
  242. print(json.dumps({"success": False, "error": f"截图失败: {msg}"}))
  243. sys.exit(1)
  244. else:
  245. print(json.dumps({"success": False, "error": "缺少 --adb/--device/--screenshot/--template 参数"}))
  246. sys.exit(1)
  247. else:
  248. # 用法1:位置参数
  249. if len(sys.argv) < 3:
  250. print(json.dumps({"success": False, "error": "用法: image-match.py <screenshot_path> <template_path> [threshold] [method=feature|template]"}))
  251. sys.exit(1)
  252. screenshot_path = sys.argv[1]
  253. template_path = sys.argv[2]
  254. threshold = float(sys.argv[3]) if len(sys.argv) > 3 else 0.8
  255. if len(sys.argv) > 4 and sys.argv[4].lower() in ('template', 'feature'):
  256. method = sys.argv[4].lower()
  257. if not os.path.exists(screenshot_path):
  258. print(json.dumps({"success": False, "error": f"截图文件不存在: {screenshot_path}"}))
  259. sys.exit(1)
  260. if not os.path.exists(template_path):
  261. print(json.dumps({"success": False, "error": f"模板文件不存在: {template_path}"}))
  262. sys.exit(1)
  263. screenshot = load_image(screenshot_path)
  264. template = load_image(template_path)
  265. if screenshot is None:
  266. print(json.dumps({"success": False, "error": "无法读取截图(文件损坏或格式不支持)"}))
  267. sys.exit(1)
  268. if template is None:
  269. print(json.dumps({"success": False, "error": f"无法读取模板: {template_path}"}))
  270. sys.exit(1)
  271. t_h, t_w = template.shape[:2]
  272. if method == 'template' and (t_h > screenshot.shape[0] or t_w > screenshot.shape[1]):
  273. print(json.dumps({"success": False, "error": "模板尺寸大于截图"}))
  274. sys.exit(1)
  275. if method == 'feature':
  276. # 1) LightGlue + SuperPoint 特征匹配(若已安装)
  277. if HAS_LIGHTGLUE:
  278. lg_result = match_by_lightglue(screenshot, template, device='cpu')
  279. if lg_result is not None:
  280. x, y, w, h, center_x, center_y = lg_result
  281. output = {
  282. "success": True,
  283. "x": x,
  284. "y": y,
  285. "width": w,
  286. "height": h,
  287. "center_x": center_x,
  288. "center_y": center_y
  289. }
  290. print(json.dumps(output))
  291. sys.exit(0)
  292. # 2) 回退:ORB 特征点匹配
  293. feat_result = match_by_features(screenshot, template)
  294. if feat_result is not None:
  295. x, y, w, h, center_x, center_y = feat_result
  296. output = {
  297. "success": True,
  298. "x": x,
  299. "y": y,
  300. "width": w,
  301. "height": h,
  302. "center_x": center_x,
  303. "center_y": center_y
  304. }
  305. print(json.dumps(output))
  306. sys.exit(0)
  307. # 3) 回退:多尺度模板匹配,适合简单图标/轮廓(如心形、纯色图标),跨分辨率
  308. fallback_threshold = min(threshold, 0.65)
  309. scale_result = multi_scale_template_match(screenshot, template, threshold=fallback_threshold, scale_min=scale_min, scale_max=scale_max)
  310. if scale_result is not None:
  311. x, y, w, h, center_x, center_y = scale_result
  312. output = {
  313. "success": True,
  314. "x": x,
  315. "y": y,
  316. "width": w,
  317. "height": h,
  318. "center_x": center_x,
  319. "center_y": center_y
  320. }
  321. print(json.dumps(output))
  322. sys.exit(0)
  323. print(json.dumps({"success": False, "error": "LightGlue/特征点与多尺度模板均未匹配(可检查模板是否在画面中或使用 --method template)"}))
  324. sys.exit(1)
  325. # 使用 TM_CCOEFF_NORMED 进行模板匹配(仅同分辨率推荐)
  326. result = cv2.matchTemplate(screenshot, template, cv2.TM_CCOEFF_NORMED)
  327. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  328. if max_val < threshold:
  329. print(json.dumps({"success": False, "error": f"未找到匹配 (相似度 {max_val:.3f} < {threshold})"}))
  330. sys.exit(1)
  331. x, y = int(max_loc[0]), int(max_loc[1])
  332. center_x = x + t_w // 2
  333. center_y = y + t_h // 2
  334. output = {
  335. "success": True,
  336. "x": x,
  337. "y": y,
  338. "width": t_w,
  339. "height": t_h,
  340. "center_x": center_x,
  341. "center_y": center_y
  342. }
  343. print(json.dumps(output))
  344. sys.exit(0)
  345. if __name__ == "__main__":
  346. main()