draw_point.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
"""Draw green points on an image and save the result (default: output/keypoint.png).

From code, call ``draw_point(match_dict, "output/point.png")``; the dict must
contain ``screenshot_image_path`` and ``matched_keypoints_original_xy``.

Command-line examples (run from the repository root)::

    python -m thirdparty.draw_point input/screenshot.png --npy points.npy
    python -m thirdparty.draw_point input/screenshot.png --xy "100,200;300,400;500,600"

The .npy file must hold an array of shape (N, 2), or a 1-D array of length 2N
laid out as x0, y0, x1, y1, ...
"""
  9. from __future__ import annotations
  10. import argparse
  11. import sys
  12. from pathlib import Path
  13. import cv2
  14. import numpy as np
  15. from routing import OUTPUT_DIR, repo_path
# Default destination image: used by draw_point() when no output path is given
# and as the default of the command-line ``--out`` option.
_OUTPUT_PATH = OUTPUT_DIR / "keypoint.png"
  17. def _load_points_from_npy(path: Path) -> np.ndarray:
  18. arr = np.load(path, allow_pickle=False)
  19. if arr.ndim == 1:
  20. if arr.size % 2 != 0:
  21. raise ValueError(f"一维点数组长度须为偶数,当前 size={arr.size}")
  22. arr = arr.reshape(-1, 2)
  23. elif arr.ndim == 2:
  24. if arr.shape[1] != 2:
  25. raise ValueError(f".npy 二维数组形状须为 (N,2),当前 {arr.shape}")
  26. else:
  27. raise ValueError(f"不支持的数组维度: {arr.ndim}")
  28. return arr.astype(np.float64)
  29. def _parse_xy_string(s: str) -> np.ndarray:
  30. parts = [p.strip() for p in s.split(";") if p.strip()]
  31. if not parts:
  32. raise ValueError("未解析到任何点,格式示例:100,200;300,400")
  33. rows: list[list[float]] = []
  34. for p in parts:
  35. xy = [float(x.strip()) for x in p.split(",")]
  36. if len(xy) != 2:
  37. raise ValueError(f"每点须为 x,y,无效片段:{p!r}")
  38. rows.append(xy)
  39. return np.asarray(rows, dtype=np.float64)
  40. def draw_point(
  41. match_dict: dict,
  42. output_image_path: str | Path | None = None,
  43. *,
  44. point_color_bgr: tuple[int, int, int] = (0, 255, 0),
  45. radius: int = 8,
  46. thickness: int = -1,
  47. ) -> Path:
  48. """在 ``match_dict['screenshot_image_path']`` 上绘制 ``matched_keypoints_original_xy``。"""
  49. out = (
  50. repo_path(output_image_path)
  51. if output_image_path is not None
  52. else _OUTPUT_PATH.resolve()
  53. )
  54. scr = repo_path(match_dict["screenshot_image_path"])
  55. raw = match_dict.get("matched_keypoints_original_xy")
  56. if raw is None:
  57. raise KeyError("match_dict 缺少 matched_keypoints_original_xy")
  58. pts = np.asarray(raw, dtype=np.float64).reshape(-1, 2)
  59. bgr = cv2.imread(str(scr), cv2.IMREAD_COLOR)
  60. if bgr is None:
  61. raise FileNotFoundError(f"无法读取截图:{scr}")
  62. for x, y in pts:
  63. cv2.circle(
  64. bgr,
  65. (int(round(x)), int(round(y))),
  66. radius,
  67. point_color_bgr,
  68. thickness,
  69. lineType=cv2.LINE_AA,
  70. )
  71. out.parent.mkdir(parents=True, exist_ok=True)
  72. if not cv2.imwrite(str(out), bgr):
  73. raise OSError(f"无法写入:{out}")
  74. return out
  75. class DrawPoint:
  76. """Thin wrapper around :func:`draw_point` when you already have an ``(N, 2)`` array."""
  77. def __init__(
  78. self,
  79. points: np.ndarray | list,
  80. screenshot_image_path: str | Path,
  81. output_image_path: str | Path | None = None,
  82. *,
  83. point_color_bgr: tuple[int, int, int] = (0, 255, 0),
  84. radius: int = 8,
  85. thickness: int = -1,
  86. ) -> None:
  87. self._match_dict = {
  88. "screenshot_image_path": str(repo_path(screenshot_image_path)),
  89. "matched_keypoints_original_xy": np.asarray(points, dtype=np.float64).tolist(),
  90. }
  91. self._output_image_path = (
  92. repo_path(output_image_path) if output_image_path is not None else None
  93. )
  94. self._style = dict(
  95. point_color_bgr=point_color_bgr,
  96. radius=radius,
  97. thickness=thickness,
  98. )
  99. def draw_point(self) -> Path:
  100. return draw_point(self._match_dict, self._output_image_path, **self._style)
  101. def main() -> None:
  102. parser = argparse.ArgumentParser(description="在截图上绘制关键点并保存。")
  103. parser.add_argument("screenshot", type=Path, help="截图路径(相对路径相对仓库根目录)")
  104. g = parser.add_mutually_exclusive_group(required=True)
  105. g.add_argument("--npy", type=Path, help="点坐标 .npy")
  106. g.add_argument("--xy", type=str, help='点列表,如 "100,200;300,400"')
  107. parser.add_argument(
  108. "-o",
  109. "--out",
  110. type=Path,
  111. default=_OUTPUT_PATH,
  112. help=f"输出路径(默认 {_OUTPUT_PATH})",
  113. )
  114. parser.add_argument("--radius", type=int, default=8)
  115. args = parser.parse_args()
  116. if args.npy is not None:
  117. pts = _load_points_from_npy(repo_path(args.npy))
  118. else:
  119. pts = _parse_xy_string(args.xy)
  120. out_path = repo_path(args.out)
  121. d = {
  122. "screenshot_image_path": str(repo_path(args.screenshot)),
  123. "matched_keypoints_original_xy": pts.tolist(),
  124. }
  125. draw_point(d, out_path, radius=args.radius)
  126. print(f"已写入:{out_path}", flush=True)
  127. if __name__ == "__main__":
  128. if sys.platform == "win32":
  129. for stream in (sys.stdout, sys.stderr):
  130. try:
  131. stream.reconfigure(encoding="utf-8")
  132. except Exception:
  133. pass
  134. main()
  135. __all__ = ["draw_point", "DrawPoint"]