ocr.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. Run OCR on an image with the project-bundled python/RapidOCR and print the result
  5. as JSON on stdout, where it is consumed by the Node layer.
  6. Usage 1: python ocr.py --image <image path> [--project-root <project root>]
  7. Output: {"success": true, "text": "recognized text"} or {"success": false, "error": "..."}
  8. Usage 2: python ocr.py --image <image path> --find-text "text to find" [--project-root <project root>]
  9. Finds the text in the image and returns its center point: {"success": true, "x": 123, "y": 456} or {"success": false, "error": "..."}
  10. """
  11. import sys
  12. import os
  13. import json
  14. import argparse
  15. def box_center(box):
  16. """box 为 4 个点 [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] 或类似,返回中心 (cx, cy)"""
  17. if box is None or len(box) < 4:
  18. return None
  19. try:
  20. xs = [float(p[0]) for p in box]
  21. ys = [float(p[1]) for p in box]
  22. except (TypeError, IndexError):
  23. return None
  24. return (sum(xs) / len(xs), sum(ys) / len(ys))
  25. def polys_to_bbox_center(polys):
  26. """多个四边形合并为外接矩形,返回中心 (cx, cy)。polys 为 list of 4-point boxes。"""
  27. if not polys:
  28. return None
  29. all_xs, all_ys = [], []
  30. for box in polys:
  31. if box is None or len(box) < 4:
  32. continue
  33. try:
  34. for p in box:
  35. all_xs.append(float(p[0]))
  36. all_ys.append(float(p[1]))
  37. except (TypeError, IndexError):
  38. continue
  39. if not all_xs or not all_ys:
  40. return None
  41. cx = (min(all_xs) + max(all_xs)) / 2
  42. cy = (min(all_ys) + max(all_ys)) / 2
  43. return (cx, cy)
  44. def normalize_for_match(s):
  45. """规范化后用于匹配:去空格、全角括号/数字转半角"""
  46. if not s:
  47. return ''
  48. s = (s or '').strip().replace(' ', '').replace('\u3000', '')
  49. t = []
  50. for c in s:
  51. if c in ('(', '[', '{'):
  52. t.append('(')
  53. elif c in (')', ']', '}'):
  54. t.append(')')
  55. elif '\uff10' <= c <= '\uff19':
  56. t.append(chr(ord(c) - 0xFEE0))
  57. else:
  58. t.append(c)
  59. return ''.join(t)
  60. def main():
  61. ap = argparse.ArgumentParser()
  62. ap.add_argument('--image', required=True, help='图片路径(绝对或相对)')
  63. ap.add_argument('--find-text', default=None, help='要查找的文字;若指定则返回该文字在图中的中心点 x,y')
  64. ap.add_argument('--project-root', default=None, help='项目根目录,用于解析相对路径及加载 RapidOCR')
  65. args = ap.parse_args()
  66. project_root = args.project_root
  67. if not project_root:
  68. project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
  69. project_root = os.path.normpath(project_root)
  70. rapidocr_python = os.path.join(project_root, 'python', 'RapidOCR', 'python')
  71. if os.path.isdir(rapidocr_python) and rapidocr_python not in sys.path:
  72. sys.path.insert(0, rapidocr_python)
  73. image_path = args.image
  74. if not os.path.isabs(image_path):
  75. image_path = os.path.normpath(os.path.join(project_root, image_path))
  76. if not os.path.isfile(image_path):
  77. out = {'success': False, 'error': f'图片不存在: {image_path}'}
  78. print(json.dumps(out, ensure_ascii=False))
  79. return
  80. try:
  81. from rapidocr import RapidOCR
  82. except ImportError as e:
  83. out = {
  84. 'success': False,
  85. 'error': f'RapidOCR 导入失败: {str(e).strip()}。请确保 python/RapidOCR 存在且安装依赖(如 pip install onnxruntime;或 cd python/RapidOCR/python && pip install -e .)。',
  86. }
  87. print(json.dumps(out, ensure_ascii=False))
  88. sys.exit(1)
  89. try:
  90. find_text = (args.find_text or '').strip()
  91. engine = RapidOCR()
  92. result = engine(image_path)
  93. if result.boxes is None or result.txts is None or len(result.txts) == 0:
  94. if find_text:
  95. out = {'success': False, 'error': f'图中未识别到文字,或未找到: "{find_text}"'}
  96. print(json.dumps(out, ensure_ascii=False))
  97. sys.exit(1)
  98. out = {'success': True, 'text': ''}
  99. print(json.dumps(out, ensure_ascii=False))
  100. return
  101. boxes = result.boxes
  102. txts = [str(t).strip() if t is not None else '' for t in result.txts]
  103. n = min(len(boxes), len(txts))
  104. def poly_at(j):
  105. if j >= len(boxes):
  106. return None
  107. p = boxes[j]
  108. if p is not None and hasattr(p, 'tolist'):
  109. return p.tolist()
  110. return list(p) if p is not None else None
  111. if find_text:
  112. find_norm = normalize_for_match(find_text)
  113. # 匹配顺序:单条严格等于 → 单条包含 → 多段拼接严格等于
  114. for i in range(n):
  115. text = txts[i] if i < len(txts) else ''
  116. text_norm = normalize_for_match(text)
  117. if text == find_text or (find_norm and text_norm == find_norm):
  118. center = box_center(poly_at(i))
  119. if center is not None:
  120. out = {'success': True, 'x': int(round(center[0])), 'y': int(round(center[1]))}
  121. print(json.dumps(out, ensure_ascii=False))
  122. return
  123. for i in range(n):
  124. text = txts[i] if i < len(txts) else ''
  125. text_norm = normalize_for_match(text)
  126. if find_text in text or (find_norm and find_norm in text_norm):
  127. center = box_center(poly_at(i))
  128. if center is not None:
  129. out = {'success': True, 'x': int(round(center[0])), 'y': int(round(center[1]))}
  130. print(json.dumps(out, ensure_ascii=False))
  131. return
  132. for start in range(n):
  133. for end in range(start + 1, n + 1):
  134. seg_text = ''.join(txts[start:end])
  135. seg_norm = normalize_for_match(seg_text)
  136. if seg_text == find_text or (find_norm and seg_norm == find_norm):
  137. merge_polys = [poly_at(j) for j in range(start, end)]
  138. merge_polys = [p for p in merge_polys if p is not None and len(p) >= 4]
  139. center = polys_to_bbox_center(merge_polys)
  140. if center is not None:
  141. out = {'success': True, 'x': int(round(center[0])), 'y': int(round(center[1]))}
  142. print(json.dumps(out, ensure_ascii=False))
  143. return
  144. break
  145. # 多段拼接后包含查找词(如图中为「下一步(2)」时仍能匹配「下一步」)
  146. for start in range(n):
  147. for end in range(start + 1, n + 1):
  148. seg_text = ''.join(txts[start:end])
  149. seg_norm = normalize_for_match(seg_text)
  150. if (find_text in seg_text) or (find_norm and find_norm in seg_norm):
  151. merge_polys = [poly_at(j) for j in range(start, end)]
  152. merge_polys = [p for p in merge_polys if p is not None and len(p) >= 4]
  153. center = polys_to_bbox_center(merge_polys)
  154. if center is not None:
  155. out = {'success': True, 'x': int(round(center[0])), 'y': int(round(center[1]))}
  156. print(json.dumps(out, ensure_ascii=False))
  157. return
  158. break
  159. out = {'success': False, 'error': f'图中未找到文字: "{find_text}"'}
  160. print(json.dumps(out, ensure_ascii=False))
  161. sys.exit(1)
  162. else:
  163. text = '\n'.join(txts) if txts else ''
  164. out = {'success': True, 'text': text}
  165. print(json.dumps(out, ensure_ascii=False))
  166. except Exception as e:
  167. out = {'success': False, 'error': str(e).strip()}
  168. print(json.dumps(out, ensure_ascii=False))
  169. sys.exit(1)
  170. if __name__ == '__main__':
  171. main()