"""Draw green points on an image and save the result (default: output/keypoint.png).

From code, call ``draw_point(match_dict, "output/point.png")``; the dict must
contain ``screenshot_image_path`` and ``matched_keypoints_original_xy``.

Command-line examples (run from the repository root)::

    python -m thirdparty.draw_point input/screenshot.png --npy points.npy
    python -m thirdparty.draw_point input/screenshot.png --xy "100,200;300,400;500,600"

The ``.npy`` array must have shape ``(N, 2)`` or be a 1-D array of length 2N
(flattened as x0, y0, x1, y1, ...).
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

import cv2
import numpy as np

from routing import OUTPUT_DIR, repo_path

__all__ = ["draw_point", "DrawPoint"]

_OUTPUT_PATH = OUTPUT_DIR / "keypoint.png"


def _load_points_from_npy(path: Path) -> np.ndarray:
    """Load point coordinates from a ``.npy`` file as a float64 ``(N, 2)`` array.

    A 1-D array of even length is reshaped into ``(x, y)`` rows; a 2-D array
    must already have exactly two columns.

    Raises:
        ValueError: If the array is 1-D with odd length, 2-D without exactly
            two columns, or has any other number of dimensions.
    """
    # allow_pickle=False: refuse object arrays so loading never unpickles data.
    arr = np.load(path, allow_pickle=False)
    if arr.ndim == 1:
        if arr.size % 2 != 0:
            raise ValueError(f"一维点数组长度须为偶数,当前 size={arr.size}")
        arr = arr.reshape(-1, 2)
    elif arr.ndim == 2:
        if arr.shape[1] != 2:
            raise ValueError(f".npy 二维数组形状须为 (N,2),当前 {arr.shape}")
    else:
        raise ValueError(f"不支持的数组维度: {arr.ndim}")
    return arr.astype(np.float64)


def _parse_xy_string(s: str) -> np.ndarray:
    """Parse ``"x0,y0;x1,y1;..."`` into a float64 ``(N, 2)`` array.

    Fragments are separated by ``;``; empty fragments are ignored.

    Raises:
        ValueError: If no fragment is present, or a fragment is not exactly
            two comma-separated numbers.
    """
    parts = [p.strip() for p in s.split(";") if p.strip()]
    if not parts:
        raise ValueError("未解析到任何点,格式示例:100,200;300,400")
    rows: list[list[float]] = []
    for p in parts:
        fields = p.split(",")
        # Check the field count before converting, so that malformed input
        # such as "1,,2" reports the offending fragment instead of a bare
        # float-conversion error.
        if len(fields) != 2:
            raise ValueError(f"每点须为 x,y,无效片段:{p!r}")
        try:
            rows.append([float(field.strip()) for field in fields])
        except ValueError as err:
            raise ValueError(f"每点须为 x,y,无效片段:{p!r}") from err
    return np.asarray(rows, dtype=np.float64)


def _read_image_bgr(path: Path) -> np.ndarray | None:
    """Decode the image file at *path* as BGR; return ``None`` on any failure.

    The bytes are read with NumPy and decoded in memory because
    ``cv2.imread`` cannot open paths containing characters outside the ANSI
    code page on Windows.
    """
    try:
        buf = np.fromfile(str(path), dtype=np.uint8)
    except OSError:
        return None
    if buf.size == 0:
        # imdecode rejects an empty buffer; treat it like an unreadable file.
        return None
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)


def _write_image(path: Path, image: np.ndarray) -> bool:
    """Encode *image* according to the suffix of *path* and write it to disk.

    Returns ``True`` on success and ``False`` if encoding or writing failed.
    Encoding happens in memory for the same reason as in ``_read_image_bgr``:
    ``cv2.imwrite`` has the same path limitation on Windows.
    """
    ok, encoded = cv2.imencode(Path(path).suffix, image)
    if not ok:
        return False
    try:
        encoded.tofile(str(path))
    except OSError:
        return False
    return True


def draw_point(
    match_dict: dict,
    output_image_path: str | Path | None = None,
    *,
    point_color_bgr: tuple[int, int, int] = (0, 255, 0),
    radius: int = 8,
    thickness: int = -1,
) -> Path:
    """Draw ``matched_keypoints_original_xy`` on ``match_dict['screenshot_image_path']``.

    Args:
        match_dict: Mapping with ``screenshot_image_path`` (the image to draw
            on) and ``matched_keypoints_original_xy`` (coordinates that can be
            reshaped to ``(N, 2)`` as x, y pairs).
        output_image_path: Where to save the annotated image. When ``None``,
            the module default ``_OUTPUT_PATH`` is used.
        point_color_bgr: Circle colour in BGR order (default is green).
        radius: Circle radius passed to ``cv2.circle``.
        thickness: Circle thickness passed to ``cv2.circle``; ``-1`` fills it.

    Returns:
        The path the image was written to.

    Raises:
        KeyError: If ``screenshot_image_path`` or
            ``matched_keypoints_original_xy`` is missing from ``match_dict``.
        ValueError: If the coordinates cannot be reshaped into ``(N, 2)``.
        FileNotFoundError: If the screenshot cannot be read or decoded.
        OSError: If the output image cannot be written.
    """
    out = (
        repo_path(output_image_path)
        if output_image_path is not None
        else _OUTPUT_PATH.resolve()
    )
    scr = repo_path(match_dict["screenshot_image_path"])
    raw = match_dict.get("matched_keypoints_original_xy")
    if raw is None:
        raise KeyError("match_dict 缺少 matched_keypoints_original_xy")
    pts = np.asarray(raw, dtype=np.float64).reshape(-1, 2)

    bgr = _read_image_bgr(scr)
    if bgr is None:
        raise FileNotFoundError(f"无法读取截图:{scr}")

    for x, y in pts:
        # cv2.circle needs integer pixel coordinates.
        cv2.circle(
            bgr,
            (int(round(x)), int(round(y))),
            radius,
            point_color_bgr,
            thickness,
            lineType=cv2.LINE_AA,
        )

    out.parent.mkdir(parents=True, exist_ok=True)
    if not _write_image(out, bgr):
        raise OSError(f"无法写入:{out}")
    return out


class DrawPoint:
    """Thin wrapper around :func:`draw_point` when you already have an ``(N, 2)`` array.

    The constructor only stores the inputs; nothing is read or written until
    :meth:`draw_point` is called.
    """

    def __init__(
        self,
        points: np.ndarray | list,
        screenshot_image_path: str | Path,
        output_image_path: str | Path | None = None,
        *,
        point_color_bgr: tuple[int, int, int] = (0, 255, 0),
        radius: int = 8,
        thickness: int = -1,
    ) -> None:
        """Store the points, paths and drawing style for a later draw call.

        Args:
            points: Point coordinates, converted to float64 and stored as
                nested lists.
            screenshot_image_path: Image to draw on.
            output_image_path: Destination path, or ``None`` to let
                :func:`draw_point` pick its default.
            point_color_bgr: Circle colour in BGR order.
            radius: Circle radius.
            thickness: Circle thickness; ``-1`` fills the circle.
        """
        # Same dict layout that the module-level draw_point() expects.
        self._match_dict = {
            "screenshot_image_path": str(repo_path(screenshot_image_path)),
            "matched_keypoints_original_xy": np.asarray(
                points, dtype=np.float64
            ).tolist(),
        }
        self._output_image_path = (
            repo_path(output_image_path) if output_image_path is not None else None
        )
        self._style = dict(
            point_color_bgr=point_color_bgr,
            radius=radius,
            thickness=thickness,
        )

    def draw_point(self) -> Path:
        """Draw the stored points and return the path of the written image."""
        return draw_point(self._match_dict, self._output_image_path, **self._style)


def main() -> None:
    """Command-line entry point: parse arguments, draw the points, print the output path."""
    parser = argparse.ArgumentParser(description="在截图上绘制关键点并保存。")
    parser.add_argument(
        "screenshot", type=Path, help="截图路径(相对路径相对仓库根目录)"
    )
    # Exactly one source of points must be supplied.
    g = parser.add_mutually_exclusive_group(required=True)
    g.add_argument("--npy", type=Path, help="点坐标 .npy")
    g.add_argument("--xy", type=str, help='点列表,如 "100,200;300,400"')
    parser.add_argument(
        "-o",
        "--out",
        type=Path,
        default=_OUTPUT_PATH,
        help=f"输出路径(默认 {_OUTPUT_PATH})",
    )
    parser.add_argument("--radius", type=int, default=8)
    args = parser.parse_args()

    if args.npy is not None:
        pts = _load_points_from_npy(repo_path(args.npy))
    else:
        pts = _parse_xy_string(args.xy)

    out_path = repo_path(args.out)
    d = {
        "screenshot_image_path": str(repo_path(args.screenshot)),
        "matched_keypoints_original_xy": pts.tolist(),
    }
    draw_point(d, out_path, radius=args.radius)
    print(f"已写入:{out_path}", flush=True)


if __name__ == "__main__":
    if sys.platform == "win32":
        # Switch the console streams to UTF-8 so the non-ASCII messages
        # print correctly on Windows.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.reconfigure(encoding="utf-8")
            except Exception:
                # Deliberately best-effort: a replaced or missing stream may
                # not support reconfigure(), and that must not stop the tool.
                pass
    main()