#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Template matching: locate a template image inside a screenshot.

Usage 1: python image-match.py <screenshot_path> <template_path> [threshold] [method=feature|template]
Usage 2: python image-match.py --adb <adb_path> --device <device_id> --screenshot <out_path> --template <template_path> [--threshold 0.8] [--method template|feature]

Usage 2 takes the adb screenshot from inside Python, which avoids the
compatibility problems caused by Node handling the binary data.

--method feature:  feature-point matching (LightGlue first; on failure ORB,
                   then multi-scale template matching); works across
                   different resolutions
--method template: pixel template matching (TM_CCOEFF_NORMED); only suitable
                   when screenshot and template share the same resolution

Output: JSON on stdout
"""
  12. import sys
  13. import os
  14. import json
  15. import subprocess
  16. try:
  17. import cv2
  18. import numpy as np
  19. except ImportError as e:
  20. print(json.dumps({"success": False, "error": f"OpenCV 导入失败: {e}。请安装: pip install opencv-python numpy"}))
  21. sys.exit(1)
  22. try:
  23. from PIL import Image as PILImage
  24. HAS_PIL = True
  25. except ImportError:
  26. HAS_PIL = False
  27. # LightGlue:若已安装(python/LightGlue pip install -e .),则优先用于 feature 匹配
  28. HAS_LIGHTGLUE = False
  29. try:
  30. _lg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'LightGlue'))
  31. if _lg_root not in sys.path:
  32. sys.path.insert(0, _lg_root)
  33. from lightglue import LightGlue, SuperPoint
  34. from lightglue.utils import match_pair
  35. import torch
  36. HAS_LIGHTGLUE = True
  37. except Exception:
  38. pass
  39. def run_adb_screencap(adb_path, device, output_path):
  40. """在 Python 内执行 adb 截图,直接处理二进制流"""
  41. # Windows 下子进程需要可执行路径,正斜杠也可用
  42. args = [adb_path.replace('/', os.sep), '-s', device, 'exec-out', 'screencap', '-p']
  43. try:
  44. result = subprocess.run(args, capture_output=True, timeout=15)
  45. if result.returncode != 0:
  46. return False, (result.stderr or result.stdout or b'').decode('utf-8', errors='replace')
  47. data = result.stdout
  48. if not data or len(data) < 100:
  49. return False, "截图数据为空"
  50. # 注意:不要对 PNG 数据做 \r\n 替换,会破坏 IDAT 压缩块导致无法解析
  51. out_dir = os.path.dirname(output_path)
  52. if out_dir:
  53. os.makedirs(out_dir, exist_ok=True)
  54. with open(output_path, 'wb') as f:
  55. f.write(data)
  56. return True, output_path
  57. except subprocess.TimeoutExpired:
  58. return False, "截图超时"
  59. except Exception as e:
  60. return False, str(e)
  61. def load_image(path):
  62. """从文件路径加载图片,兼容 OpenCV 无法直接读取的 PNG(如部分 Android 截图)"""
  63. if not os.path.exists(path):
  64. return None
  65. with open(path, 'rb') as f:
  66. data = np.frombuffer(f.read(), dtype=np.uint8)
  67. img = cv2.imdecode(data, cv2.IMREAD_COLOR)
  68. if img is not None:
  69. return img
  70. img = cv2.imread(path)
  71. if img is not None:
  72. return img
  73. if HAS_PIL:
  74. try:
  75. pil_img = PILImage.open(path).convert('RGB')
  76. img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
  77. return img
  78. except Exception:
  79. pass
  80. return None
  81. def _numpy_bgr_to_torch_rgb(img_bgr):
  82. """(H,W,3) BGR numpy uint8 -> (3,H,W) float [0,1] RGB for LightGlue"""
  83. rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  84. t = np.ascontiguousarray(rgb.transpose(2, 0, 1))
  85. return torch.from_numpy(t).float().div(255.0)
  86. def match_by_lightglue(screenshot, template, min_matches=8, device='cpu'):
  87. """
  88. 使用 LightGlue + SuperPoint 做特征匹配,在截图中找模板位置。
  89. 返回 (x, y, w, h, center_x, center_y) 或 None。
  90. """
  91. if not HAS_LIGHTGLUE:
  92. return None
  93. t_h, t_w = template.shape[:2]
  94. try:
  95. img0 = _numpy_bgr_to_torch_rgb(screenshot)
  96. img1 = _numpy_bgr_to_torch_rgb(template)
  97. extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
  98. matcher = LightGlue(features='superpoint').eval().to(device)
  99. feats0, feats1, matches01 = match_pair(extractor, matcher, img0, img1, device=device)
  100. matches = matches01.get('matches')
  101. if matches is None or matches.shape[0] < min_matches:
  102. return None
  103. kp0 = feats0['keypoints']
  104. kp1 = feats1['keypoints']
  105. idx0 = matches[:, 0]
  106. idx1 = matches[:, 1]
  107. pts_screen = kp0[idx0].cpu().numpy().astype(np.float32)
  108. pts_template = kp1[idx1].cpu().numpy().astype(np.float32)
  109. H, mask = cv2.findHomography(pts_template, pts_screen, cv2.RANSAC, 5.0)
  110. if H is None:
  111. return None
  112. corners = np.float32([[0, 0], [t_w, 0], [t_w, t_h], [0, t_h]]).reshape(-1, 1, 2)
  113. corners_screen = cv2.perspectiveTransform(corners, H)
  114. x_coords = corners_screen[:, 0, 0]
  115. y_coords = corners_screen[:, 0, 1]
  116. x = int(round(np.min(x_coords)))
  117. y = int(round(np.min(y_coords)))
  118. w = int(round(np.max(x_coords) - np.min(x_coords)))
  119. h = int(round(np.max(y_coords) - np.min(y_coords)))
  120. center_x = int(round(np.mean(x_coords)))
  121. center_y = int(round(np.mean(y_coords)))
  122. return (x, y, w, h, center_x, center_y)
  123. except Exception:
  124. return None
  125. def match_by_features(screenshot, template, min_good_matches=8):
  126. """
  127. 基于特征点(ORB)匹配作为回退:在截图中找模板位置,返回 (x, y, w, h, center_x, center_y) 或 None。
  128. """
  129. gray_screen = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
  130. gray_tpl = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
  131. t_h, t_w = template.shape[:2]
  132. orb = cv2.ORB_create(nfeatures=2000)
  133. kp1, desc1 = orb.detectAndCompute(gray_tpl, None)
  134. kp2, desc2 = orb.detectAndCompute(gray_screen, None)
  135. if desc1 is None or desc2 is None or len(kp1) < 4 or len(kp2) < 4:
  136. return None
  137. bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
  138. matches = bf.knnMatch(desc1, desc2, k=2)
  139. good = []
  140. for m_n in matches:
  141. if len(m_n) != 2:
  142. continue
  143. m, n = m_n
  144. if m.distance < 0.75 * n.distance:
  145. good.append(m)
  146. if len(good) < min_good_matches:
  147. return None
  148. src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
  149. dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
  150. H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
  151. if H is None:
  152. return None
  153. # 模板四角在截图中的坐标,用质心作为中心点
  154. corners = np.float32([[0, 0], [t_w, 0], [t_w, t_h], [0, t_h]]).reshape(-1, 1, 2)
  155. corners_screen = cv2.perspectiveTransform(corners, H)
  156. x_coords = corners_screen[:, 0, 0]
  157. y_coords = corners_screen[:, 0, 1]
  158. x = int(round(np.min(x_coords)))
  159. y = int(round(np.min(y_coords)))
  160. w = int(round(np.max(x_coords) - np.min(x_coords)))
  161. h = int(round(np.max(y_coords) - np.min(y_coords)))
  162. center_x = int(round(np.mean(x_coords)))
  163. center_y = int(round(np.mean(y_coords)))
  164. return (x, y, w, h, center_x, center_y)
  165. def multi_scale_template_match(screenshot, template, threshold=0.65):
  166. """
  167. 多尺度模板匹配:对模板做多种缩放后在截图中匹配,适配不同分辨率(如简单图标、轮廓)。
  168. 返回 (x, y, w, h, center_x, center_y) 或 None。
  169. """
  170. gray_screen = cv2.cvtColor(screenshot, cv2.COLOR_BGR2GRAY)
  171. gray_tpl = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
  172. sh, sw = screenshot.shape[:2]
  173. t_h, t_w = template.shape[:2]
  174. best = None
  175. best_val = threshold
  176. # 从 0.4 到 1.6 倍缩放,步长约 0.15,保证缩放后不超出截图
  177. for scale in np.arange(0.4, 1.65, 0.12):
  178. w = max(8, int(round(t_w * scale)))
  179. h = max(8, int(round(t_h * scale)))
  180. if h > sh or w > sw:
  181. continue
  182. resized = cv2.resize(gray_tpl, (w, h), interpolation=cv2.INTER_AREA)
  183. result = cv2.matchTemplate(gray_screen, resized, cv2.TM_CCOEFF_NORMED)
  184. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  185. if max_val > best_val:
  186. best_val = max_val
  187. x, y = int(max_loc[0]), int(max_loc[1])
  188. center_x = x + w // 2
  189. center_y = y + h // 2
  190. best = (x, y, w, h, center_x, center_y)
  191. return best
  192. def main():
  193. screenshot_path = None
  194. template_path = None
  195. threshold = 0.8
  196. method = 'feature' # feature=特征点匹配(跨分辨率), template=像素模板匹配
  197. adb_path = None
  198. device = None
  199. if len(sys.argv) >= 2 and sys.argv[1] == '--adb':
  200. # 用法2:--adb --device --screenshot --template
  201. i = 1
  202. while i < len(sys.argv):
  203. if sys.argv[i] == '--adb' and i + 1 < len(sys.argv):
  204. adb_path = sys.argv[i + 1]
  205. i += 2
  206. elif sys.argv[i] == '--device' and i + 1 < len(sys.argv):
  207. device = sys.argv[i + 1]
  208. i += 2
  209. elif sys.argv[i] == '--screenshot' and i + 1 < len(sys.argv):
  210. screenshot_path = sys.argv[i + 1]
  211. i += 2
  212. elif sys.argv[i] == '--template' and i + 1 < len(sys.argv):
  213. template_path = sys.argv[i + 1]
  214. i += 2
  215. elif sys.argv[i] == '--threshold' and i + 1 < len(sys.argv):
  216. threshold = float(sys.argv[i + 1])
  217. i += 2
  218. elif sys.argv[i] == '--method' and i + 1 < len(sys.argv):
  219. method = (sys.argv[i + 1] or 'feature').strip().lower()
  220. if method not in ('template', 'feature'):
  221. method = 'feature'
  222. i += 2
  223. else:
  224. i += 1
  225. if adb_path and device and screenshot_path and template_path:
  226. ok, msg = run_adb_screencap(adb_path, device, screenshot_path)
  227. if not ok:
  228. print(json.dumps({"success": False, "error": f"截图失败: {msg}"}))
  229. sys.exit(1)
  230. else:
  231. print(json.dumps({"success": False, "error": "缺少 --adb/--device/--screenshot/--template 参数"}))
  232. sys.exit(1)
  233. else:
  234. # 用法1:位置参数
  235. if len(sys.argv) < 3:
  236. print(json.dumps({"success": False, "error": "用法: image-match.py <screenshot_path> <template_path> [threshold] [method=feature|template]"}))
  237. sys.exit(1)
  238. screenshot_path = sys.argv[1]
  239. template_path = sys.argv[2]
  240. threshold = float(sys.argv[3]) if len(sys.argv) > 3 else 0.8
  241. if len(sys.argv) > 4 and sys.argv[4].lower() in ('template', 'feature'):
  242. method = sys.argv[4].lower()
  243. if not os.path.exists(screenshot_path):
  244. print(json.dumps({"success": False, "error": f"截图文件不存在: {screenshot_path}"}))
  245. sys.exit(1)
  246. if not os.path.exists(template_path):
  247. print(json.dumps({"success": False, "error": f"模板文件不存在: {template_path}"}))
  248. sys.exit(1)
  249. screenshot = load_image(screenshot_path)
  250. template = load_image(template_path)
  251. if screenshot is None:
  252. print(json.dumps({"success": False, "error": "无法读取截图(文件损坏或格式不支持)"}))
  253. sys.exit(1)
  254. if template is None:
  255. print(json.dumps({"success": False, "error": f"无法读取模板: {template_path}"}))
  256. sys.exit(1)
  257. t_h, t_w = template.shape[:2]
  258. if method == 'template' and (t_h > screenshot.shape[0] or t_w > screenshot.shape[1]):
  259. print(json.dumps({"success": False, "error": "模板尺寸大于截图"}))
  260. sys.exit(1)
  261. if method == 'feature':
  262. # 1) LightGlue + SuperPoint 特征匹配(若已安装)
  263. if HAS_LIGHTGLUE:
  264. lg_result = match_by_lightglue(screenshot, template, device='cpu')
  265. if lg_result is not None:
  266. x, y, w, h, center_x, center_y = lg_result
  267. output = {
  268. "success": True,
  269. "x": x,
  270. "y": y,
  271. "width": w,
  272. "height": h,
  273. "center_x": center_x,
  274. "center_y": center_y
  275. }
  276. print(json.dumps(output))
  277. sys.exit(0)
  278. # 2) 回退:ORB 特征点匹配
  279. feat_result = match_by_features(screenshot, template)
  280. if feat_result is not None:
  281. x, y, w, h, center_x, center_y = feat_result
  282. output = {
  283. "success": True,
  284. "x": x,
  285. "y": y,
  286. "width": w,
  287. "height": h,
  288. "center_x": center_x,
  289. "center_y": center_y
  290. }
  291. print(json.dumps(output))
  292. sys.exit(0)
  293. # 3) 回退:多尺度模板匹配,适合简单图标/轮廓(如心形、纯色图标),跨分辨率
  294. fallback_threshold = min(threshold, 0.65)
  295. scale_result = multi_scale_template_match(screenshot, template, threshold=fallback_threshold)
  296. if scale_result is not None:
  297. x, y, w, h, center_x, center_y = scale_result
  298. output = {
  299. "success": True,
  300. "x": x,
  301. "y": y,
  302. "width": w,
  303. "height": h,
  304. "center_x": center_x,
  305. "center_y": center_y
  306. }
  307. print(json.dumps(output))
  308. sys.exit(0)
  309. print(json.dumps({"success": False, "error": "LightGlue/特征点与多尺度模板均未匹配(可检查模板是否在画面中或使用 --method template)"}))
  310. sys.exit(1)
  311. # 使用 TM_CCOEFF_NORMED 进行模板匹配(仅同分辨率推荐)
  312. result = cv2.matchTemplate(screenshot, template, cv2.TM_CCOEFF_NORMED)
  313. min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(result)
  314. if max_val < threshold:
  315. print(json.dumps({"success": False, "error": f"未找到匹配 (相似度 {max_val:.3f} < {threshold})"}))
  316. sys.exit(1)
  317. x, y = int(max_loc[0]), int(max_loc[1])
  318. center_x = x + t_w // 2
  319. center_y = y + t_h // 2
  320. output = {
  321. "success": True,
  322. "x": x,
  323. "y": y,
  324. "width": t_w,
  325. "height": t_h,
  326. "center_x": center_x,
  327. "center_y": center_y
  328. }
  329. print(json.dumps(output))
  330. sys.exit(0)
  331. if __name__ == "__main__":
  332. main()