webui.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import os
  2. import time
  3. import zipfile
  4. from flask import Flask, render_template, request, jsonify, send_file, redirect, url_for
  5. from werkzeug.utils import secure_filename
  6. from onnxocr.ocr_images_pdfs import OCRLogic
  7. import cv2
  8. import base64
  9. import numpy as np
  10. from onnxocr.onnx_paddleocr import ONNXPaddleOcr
  11. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  12. UPLOAD_ROOT = os.path.join(BASE_DIR, "uploads")
  13. RESULT_ROOT = os.path.join(BASE_DIR, "results")
  14. os.makedirs(UPLOAD_ROOT, exist_ok=True)
  15. os.makedirs(RESULT_ROOT, exist_ok=True)
  16. MODEL_OPTIONS = ["PP-OCRv5", "PP-OCRv4", "ch_ppocr_server_v2.0"]
  17. app = Flask(__name__, static_folder="static", template_folder="templates")
  18. app.config['MAX_CONTENT_LENGTH'] = 200 * 1024 * 1024 # 200MB
  19. ocr_logic = OCRLogic(lambda msg: print(msg))
  20. # 独立 OCR 模型实例,避免影响 ocr_logic
  21. ocr_model_api = ONNXPaddleOcr(use_angle_cls=True, use_gpu=False)
  22. @app.route("/")
  23. def index():
  24. return render_template("webui.html", model_options=MODEL_OPTIONS)
  25. @app.errorhandler(404)
  26. def not_found(e):
  27. path = request.path
  28. if not path.startswith("/static") and not path.startswith("/download"):
  29. return redirect(url_for("index"))
  30. return jsonify({"detail": "NotFound"}), 404
  31. @app.route("/set_model", methods=["POST"])
  32. def set_model():
  33. model_name = request.form.get("model_name")
  34. try:
  35. ocr_logic.set_model(model_name)
  36. return {"success": True, "msg": f"模型已切换为 {model_name}"}
  37. except Exception as e:
  38. return {"success": False, "msg": str(e)}
  39. @app.route("/ocr", methods=["POST"])
  40. def ocr_files():
  41. files = request.files.getlist("files")
  42. model_name = request.form.get("model_name")
  43. if not files or not model_name:
  44. return jsonify({"success": False, "msg": "缺少文件或模型参数"}), 400
  45. try:
  46. ocr_logic.set_model(model_name)
  47. except Exception as e:
  48. return jsonify({"success": False, "msg": f"模型切换失败: {e}"}), 500
  49. timestamp = time.strftime("%Y%m%d_%H%M%S")
  50. session_dir = os.path.join(RESULT_ROOT, timestamp)
  51. os.makedirs(session_dir, exist_ok=True)
  52. file_paths = []
  53. for file in files:
  54. filename = secure_filename(file.filename)
  55. file_path = os.path.join(session_dir, filename)
  56. file.save(file_path)
  57. file_paths.append(file_path)
  58. results = []
  59. def status_callback(msg): pass
  60. logic = OCRLogic(status_callback)
  61. logic.set_model(model_name)
  62. logic.run(file_paths, save_txt=True, merge_txt=False, output_img=False)
  63. txt_files = []
  64. for file_path in file_paths:
  65. out_dir = os.path.join(os.path.dirname(file_path), "Output_OCR")
  66. if not os.path.exists(out_dir):
  67. continue
  68. for fname in os.listdir(out_dir):
  69. if fname.endswith(".txt") and fname.startswith(os.path.splitext(os.path.basename(file_path))[0]):
  70. txt_files.append(os.path.join(out_dir, fname))
  71. with open(os.path.join(out_dir, fname), "r", encoding="utf-8") as f:
  72. content = f.read()
  73. results.append({"filename": fname, "content": content})
  74. zip_path = os.path.join(session_dir, f"ocr_txt_{timestamp}.zip")
  75. with zipfile.ZipFile(zip_path, "w") as zipf:
  76. for txt_file in txt_files:
  77. zipf.write(txt_file, os.path.basename(txt_file))
  78. return jsonify({
  79. "success": True,
  80. "results": results,
  81. "zip_url": f"/download/{timestamp}"
  82. })
  83. @app.route("/download/<timestamp>")
  84. def download_zip(timestamp):
  85. session_dir = os.path.join(RESULT_ROOT, timestamp)
  86. zip_path = os.path.join(session_dir, f"ocr_txt_{timestamp}.zip")
  87. if os.path.exists(zip_path):
  88. return send_file(zip_path, as_attachment=True, download_name=f"ocr_txt_{timestamp}.zip")
  89. return jsonify({"success": False, "msg": "文件不存在"}), 404
  90. @app.route("/ocr_api", methods=["POST"])
  91. def ocr_api():
  92. data = request.get_json()
  93. if not data or "image" not in data:
  94. return jsonify({"error": "Invalid request, 'image' field is required."}), 400
  95. image_base64 = data["image"]
  96. try:
  97. image_bytes = base64.b64decode(image_base64)
  98. image_np = np.frombuffer(image_bytes, dtype=np.uint8)
  99. img = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
  100. if img is None:
  101. return jsonify({"error": "Failed to decode image from base64."}), 400
  102. except Exception as e:
  103. return jsonify({"error": f"Image decoding failed: {str(e)}"}), 400
  104. start_time = time.time()
  105. result = ocr_model_api.ocr(img)
  106. end_time = time.time()
  107. processing_time = end_time - start_time
  108. ocr_results = []
  109. for line in result[0]:
  110. if isinstance(line[0], (list, np.ndarray)):
  111. bounding_box = np.array(line[0]).reshape(4, 2).tolist()
  112. else:
  113. bounding_box = []
  114. ocr_results.append({
  115. "text": line[1][0],
  116. "confidence": float(line[1][1]),
  117. "bounding_box": bounding_box
  118. })
  119. return jsonify({
  120. "processing_time": processing_time,
  121. "results": ocr_results
  122. })
  123. if __name__ == "__main__":
  124. app.run(host="0.0.0.0", port=5005, debug=True)