server.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. from fastapi import FastAPI, UploadFile, File, HTTPException
  2. from fastapi.responses import JSONResponse, StreamingResponse
  3. import uvicorn
  4. import os
  5. import zipfile
  6. import shutil
  7. from pathlib import Path
  8. import sys
  9. from typing import Generator
  10. import asyncio
  11. import websockets
  12. from contextlib import asynccontextmanager
  13. # 添加当前目录到路径,以便导入 birefnet-matting
  14. script_dir = os.path.dirname(os.path.abspath(__file__))
  15. sys.path.insert(0, script_dir)
  16. # 导入 birefnet-matting 模块
  17. try:
  18. import importlib.util
  19. spec = importlib.util.spec_from_file_location("birefnet_matting", os.path.join(script_dir, "birefnet-matting.py"))
  20. birefnet_matting = importlib.util.module_from_spec(spec)
  21. spec.loader.exec_module(birefnet_matting)
  22. except Exception as e:
  23. print(f"警告: 无法导入 birefnet-matting 模块: {e}")
  24. birefnet_matting = None
  25. # 创建必要的文件夹
  26. REC_FOLDER = os.path.join(script_dir, "rec")
  27. SEND_FOLDER = os.path.join(script_dir, "send")
  28. os.makedirs(REC_FOLDER, exist_ok=True)
  29. os.makedirs(SEND_FOLDER, exist_ok=True)
  30. # WebSocket 客户端配置
  31. WS_SERVER_URL = "ws://localhost:9527"
  32. ws_connection = None
  33. ws_running = True
  34. async def process_received_zip(filename: str, zip_data: bytes, websocket):
  35. """处理接收到的 zip 包"""
  36. import io
  37. print(f"[WebSocket] 开始处理任务: {filename}")
  38. task_rec_dir = os.path.join(REC_FOLDER, filename)
  39. task_send_dir = os.path.join(SEND_FOLDER, filename)
  40. try:
  41. # 创建任务文件夹
  42. os.makedirs(task_rec_dir, exist_ok=True)
  43. os.makedirs(task_send_dir, exist_ok=True)
  44. # 保存并解压 zip 文件
  45. zip_path = os.path.join(REC_FOLDER, f"{filename}.zip")
  46. with open(zip_path, 'wb') as f:
  47. f.write(zip_data)
  48. print(f"[WebSocket] 已保存 zip 文件: {zip_path}")
  49. # 解压到 rec/{filename}/ 文件夹
  50. with zipfile.ZipFile(zip_path, 'r') as zf:
  51. zf.extractall(task_rec_dir)
  52. print(f"[WebSocket] 已解压到: {task_rec_dir}")
  53. # 删除原始 zip 文件
  54. try:
  55. os.remove(zip_path)
  56. except:
  57. pass
  58. # 检查图片文件
  59. supported_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']
  60. image_files = []
  61. for f in os.listdir(task_rec_dir):
  62. ext = os.path.splitext(f)[1].lower()
  63. if ext in supported_formats:
  64. image_files.append(f)
  65. if not image_files:
  66. print(f"[WebSocket] 未找到图片文件")
  67. return
  68. print(f"[WebSocket] 找到 {len(image_files)} 张图片,开始处理...")
  69. # 调用 birefnet-matting 处理
  70. if birefnet_matting:
  71. birefnet_matting.process_folder_birefnet_by_id(filename, REC_FOLDER, SEND_FOLDER)
  72. else:
  73. print("[WebSocket] birefnet_matting 模块未加载,跳过处理")
  74. # 如果模块未加载,直接复制文件作为测试
  75. for img in image_files:
  76. src = os.path.join(task_rec_dir, img)
  77. dst = os.path.join(task_send_dir, img)
  78. shutil.copy(src, dst)
  79. # 等待处理完成,检查输出文件
  80. output_zip_path = os.path.join(task_send_dir, f"{filename}.zip")
  81. # 如果 birefnet 没有生成 zip,我们自己打包
  82. if not os.path.exists(output_zip_path):
  83. print(f"[WebSocket] 正在打包处理结果...")
  84. # 检查 send 文件夹中的图片
  85. output_images = []
  86. for f in os.listdir(task_send_dir):
  87. ext = os.path.splitext(f)[1].lower()
  88. if ext in supported_formats:
  89. output_images.append(f)
  90. if output_images:
  91. with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
  92. for img in output_images:
  93. img_path = os.path.join(task_send_dir, img)
  94. zf.write(img_path, img)
  95. print(f"[WebSocket] 已创建输出 zip: {output_zip_path}")
  96. # 读取处理后的 zip 并发回
  97. if os.path.exists(output_zip_path):
  98. with open(output_zip_path, 'rb') as f:
  99. result_zip_data = f.read()
  100. # 发送格式:4字节文件名长度 + 文件名 + zip数据
  101. filename_bytes = filename.encode('utf-8')
  102. message = len(filename_bytes).to_bytes(4, 'big') + filename_bytes + result_zip_data
  103. await websocket.send(message)
  104. print(f"[WebSocket] 已发送处理结果,大小: {len(result_zip_data)} 字节")
  105. else:
  106. print(f"[WebSocket] 未找到输出文件: {output_zip_path}")
  107. # 清理临时文件
  108. print(f"[WebSocket] 清理临时文件...")
  109. if os.path.exists(task_rec_dir):
  110. shutil.rmtree(task_rec_dir)
  111. if os.path.exists(task_send_dir):
  112. shutil.rmtree(task_send_dir)
  113. print(f"[WebSocket] 任务 {filename} 处理完成")
  114. except Exception as e:
  115. import traceback
  116. print(f"[WebSocket] 处理任务失败: {e}")
  117. traceback.print_exc()
  118. # 清理
  119. if os.path.exists(task_rec_dir):
  120. try:
  121. shutil.rmtree(task_rec_dir)
  122. except:
  123. pass
  124. if os.path.exists(task_send_dir):
  125. try:
  126. shutil.rmtree(task_send_dir)
  127. except:
  128. pass
  129. async def websocket_client():
  130. """WebSocket 客户端,自动连接并在断线后每秒重连"""
  131. global ws_connection, ws_running
  132. while ws_running:
  133. try:
  134. print(f"正在连接到 WebSocket 服务器: {WS_SERVER_URL}")
  135. async with websockets.connect(WS_SERVER_URL, max_size=100 * 1024 * 1024) as websocket:
  136. ws_connection = websocket
  137. print(f"已成功连接到 WebSocket 服务器: {WS_SERVER_URL}")
  138. # 保持连接并处理消息
  139. while ws_running:
  140. try:
  141. # 接收消息(设置超时以便定期检查 ws_running 状态)
  142. message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
  143. # 检查是否是二进制数据(zip 包)
  144. if isinstance(message, bytes):
  145. print(f"[WebSocket] 收到二进制数据,大小: {len(message)} 字节")
  146. # 解析消息:前4字节是文件名长度,接着是文件名,然后是zip数据
  147. try:
  148. filename_len = int.from_bytes(message[:4], 'big')
  149. filename = message[4:4+filename_len].decode('utf-8')
  150. zip_data = message[4+filename_len:]
  151. print(f"[WebSocket] 收到任务: {filename}, zip大小: {len(zip_data)} 字节")
  152. # 异步处理任务
  153. asyncio.create_task(process_received_zip(filename, zip_data, websocket))
  154. except Exception as e:
  155. print(f"[WebSocket] 解析数据失败: {e}")
  156. else:
  157. print(f"收到 WebSocket 消息: {message}")
  158. except asyncio.TimeoutError:
  159. # 超时只是为了定期检查 ws_running 状态,继续循环
  160. continue
  161. except websockets.exceptions.ConnectionClosed:
  162. print("WebSocket 连接已关闭")
  163. break
  164. except (websockets.exceptions.ConnectionClosed,
  165. websockets.exceptions.InvalidStatusCode,
  166. ConnectionRefusedError,
  167. OSError) as e:
  168. ws_connection = None
  169. if ws_running:
  170. print(f"WebSocket 连接失败或断开: {e},1秒后重试...")
  171. await asyncio.sleep(1)
  172. except Exception as e:
  173. ws_connection = None
  174. if ws_running:
  175. print(f"WebSocket 发生错误: {e},1秒后重试...")
  176. await asyncio.sleep(1)
  177. print("WebSocket 客户端已停止")
  178. @asynccontextmanager
  179. async def lifespan(app: FastAPI):
  180. """FastAPI 生命周期管理"""
  181. global ws_running
  182. # 启动时创建 WebSocket 客户端任务
  183. ws_task = asyncio.create_task(websocket_client())
  184. print("WebSocket 客户端任务已启动")
  185. yield
  186. # 关闭时停止 WebSocket 客户端
  187. ws_running = False
  188. if ws_connection:
  189. await ws_connection.close()
  190. ws_task.cancel()
  191. try:
  192. await ws_task
  193. except asyncio.CancelledError:
  194. pass
  195. print("WebSocket 客户端任务已停止")
  196. app = FastAPI(title="BiRefNet Matting Server", version="1.0.0", lifespan=lifespan)
  197. @app.get("/")
  198. async def root():
  199. """根路径,返回欢迎信息"""
  200. return {"message": "欢迎使用 BiRefNet Matting Server", "status": "running"}
  201. @app.get("/health")
  202. async def health():
  203. """健康检查端点"""
  204. return {"status": "healthy"}
  205. @app.get("/api/info")
  206. async def info():
  207. """获取服务器信息"""
  208. return {
  209. "name": "BiRefNet Matting Server",
  210. "version": "1.0.0",
  211. "description": "图像抠图服务"
  212. }
  213. @app.post("/api/process")
  214. async def process_images(file: UploadFile = File(...)):
  215. """
  216. 接收zip图包,解压后使用BiRefNet进行抠图,然后打包返回
  217. Args:
  218. file: 上传的zip文件
  219. Returns:
  220. 处理后的zip文件
  221. """
  222. if birefnet_matting is None:
  223. raise HTTPException(status_code=500, detail="BiRefNet模块未正确加载")
  224. if not file.filename.endswith('.zip'):
  225. raise HTTPException(status_code=400, detail="只支持zip格式的文件")
  226. # 使用固定的rec和send文件夹
  227. script_dir = os.path.dirname(os.path.abspath(__file__))
  228. rec_dir = os.path.join(script_dir, "rec")
  229. send_dir = os.path.join(script_dir, "send")
  230. os.makedirs(rec_dir, exist_ok=True)
  231. os.makedirs(send_dir, exist_ok=True)
  232. zip_path = None
  233. output_zip_path = None
  234. response_sent = False
  235. unique_id = None
  236. task_rec_dir = None
  237. task_send_dir = None
  238. try:
  239. # 从zip文件名提取唯一ID(去掉.zip后缀)
  240. zip_filename = file.filename
  241. if not zip_filename.endswith('.zip'):
  242. raise HTTPException(status_code=400, detail="文件必须是zip格式")
  243. unique_id = zip_filename[:-4] # 去掉 .zip 后缀
  244. print(f"接收到任务ID: {unique_id}")
  245. # 创建对应的文件夹
  246. task_rec_dir = os.path.join(rec_dir, unique_id)
  247. task_send_dir = os.path.join(send_dir, unique_id)
  248. os.makedirs(task_rec_dir, exist_ok=True)
  249. os.makedirs(task_send_dir, exist_ok=True)
  250. # 保存上传的zip文件到rec文件夹
  251. zip_path = os.path.join(rec_dir, zip_filename)
  252. with open(zip_path, "wb") as f:
  253. shutil.copyfileobj(file.file, f)
  254. # 解压zip文件到rec/{ID}/文件夹
  255. print(f"正在解压文件到: {task_rec_dir}")
  256. with zipfile.ZipFile(zip_path, 'r') as zip_ref:
  257. zip_ref.extractall(task_rec_dir)
  258. # 删除上传的zip文件(已解压,不再需要)
  259. try:
  260. os.remove(zip_path)
  261. except:
  262. pass
  263. # 检查rec/{ID}/文件夹中是否有图片文件
  264. supported_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']
  265. rec_path = Path(task_rec_dir)
  266. image_files = []
  267. for ext in supported_formats:
  268. # 只使用rglob递归查找,避免重复计算
  269. image_files.extend(rec_path.rglob(f'*{ext}'))
  270. image_files.extend(rec_path.rglob(f'*{ext.upper()}'))
  271. # 去重(使用绝对路径去重,避免重复处理)
  272. unique_paths = set()
  273. unique_files = []
  274. for f in image_files:
  275. abs_path = str(f.resolve())
  276. if abs_path not in unique_paths:
  277. unique_paths.add(abs_path)
  278. unique_files.append(f)
  279. image_files = unique_files
  280. if not image_files:
  281. raise HTTPException(status_code=400, detail="zip文件中未找到支持的图片文件")
  282. # 记录实际接收到的图片数量
  283. received_image_count = len(image_files)
  284. print(f"接收到 {received_image_count} 张图片")
  285. print("图片列表:")
  286. for img in image_files:
  287. print(f" - {img.name}")
  288. print()
  289. print("开始处理图片...")
  290. # 调用birefnet-matting处理图片,传递ID
  291. birefnet_matting.process_folder_birefnet_by_id(unique_id, rec_dir, send_dir)
  292. # 等待处理完成,检查send/{ID}/文件夹中的zip文件
  293. output_zip_path = os.path.join(task_send_dir, f"{unique_id}.zip")
  294. # 等待zip文件生成(最多等待5分钟)
  295. import time
  296. max_wait_time = 300 # 5分钟
  297. wait_interval = 1 # 每秒检查一次
  298. waited_time = 0
  299. print("等待处理完成...")
  300. while not os.path.exists(output_zip_path) and waited_time < max_wait_time:
  301. time.sleep(wait_interval)
  302. waited_time += wait_interval
  303. if waited_time % 10 == 0:
  304. print(f" 已等待 {waited_time} 秒...")
  305. if not os.path.exists(output_zip_path):
  306. raise HTTPException(status_code=500, detail="处理超时,未生成输出文件")
  307. print(f"处理完成,找到结果文件: {output_zip_path}")
  308. # 读取zip文件内容并返回
  309. def generate() -> Generator[bytes, None, None]:
  310. try:
  311. with open(output_zip_path, 'rb') as f:
  312. while True:
  313. chunk = f.read(8192) # 8KB chunks
  314. if not chunk:
  315. break
  316. yield chunk
  317. finally:
  318. # 文件发送完成后清理临时文件
  319. print("正在清理临时文件...")
  320. # 清理 rec/{ID}/ 文件夹
  321. if os.path.exists(task_rec_dir):
  322. try:
  323. shutil.rmtree(task_rec_dir)
  324. print(f"已清理: {task_rec_dir}")
  325. except:
  326. pass
  327. # 清理 send/{ID}/ 文件夹
  328. if os.path.exists(task_send_dir):
  329. try:
  330. shutil.rmtree(task_send_dir)
  331. print(f"已清理: {task_send_dir}")
  332. except:
  333. pass
  334. print("临时文件清理完成")
  335. response_sent = True
  336. return StreamingResponse(
  337. generate(),
  338. media_type="application/zip",
  339. headers={"Content-Disposition": f"attachment; filename={unique_id}.zip"}
  340. )
  341. except zipfile.BadZipFile:
  342. # 清理临时文件
  343. if task_rec_dir and os.path.exists(task_rec_dir):
  344. try:
  345. shutil.rmtree(task_rec_dir)
  346. except:
  347. pass
  348. if task_send_dir and os.path.exists(task_send_dir):
  349. try:
  350. shutil.rmtree(task_send_dir)
  351. except:
  352. pass
  353. raise HTTPException(status_code=400, detail="无效的zip文件")
  354. except Exception as e:
  355. import traceback
  356. error_msg = f"处理过程中出错: {str(e)}\n{traceback.format_exc()}"
  357. print(error_msg)
  358. # 清理临时文件
  359. if task_rec_dir and os.path.exists(task_rec_dir):
  360. try:
  361. shutil.rmtree(task_rec_dir)
  362. except:
  363. pass
  364. if task_send_dir and os.path.exists(task_send_dir):
  365. try:
  366. shutil.rmtree(task_send_dir)
  367. except:
  368. pass
  369. raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
  370. if __name__ == "__main__":
  371. uvicorn.run(app, host="0.0.0.0", port=8000)