| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450 |
- from fastapi import FastAPI, UploadFile, File, HTTPException
- from fastapi.responses import JSONResponse, StreamingResponse
- import uvicorn
- import os
- import zipfile
- import shutil
- from pathlib import Path
- import sys
- from typing import Generator
- import asyncio
- import websockets
- from contextlib import asynccontextmanager
- # 添加当前目录到路径,以便导入 birefnet-matting
- script_dir = os.path.dirname(os.path.abspath(__file__))
- sys.path.insert(0, script_dir)
- # 导入 birefnet-matting 模块
- try:
- import importlib.util
- spec = importlib.util.spec_from_file_location("birefnet_matting", os.path.join(script_dir, "birefnet-matting.py"))
- birefnet_matting = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(birefnet_matting)
- except Exception as e:
- print(f"警告: 无法导入 birefnet-matting 模块: {e}")
- birefnet_matting = None
- # 创建必要的文件夹
- REC_FOLDER = os.path.join(script_dir, "rec")
- SEND_FOLDER = os.path.join(script_dir, "send")
- os.makedirs(REC_FOLDER, exist_ok=True)
- os.makedirs(SEND_FOLDER, exist_ok=True)
- # WebSocket 客户端配置
- WS_SERVER_URL = "ws://localhost:9527"
- ws_connection = None
- ws_running = True
- async def process_received_zip(filename: str, zip_data: bytes, websocket):
- """处理接收到的 zip 包"""
- import io
-
- print(f"[WebSocket] 开始处理任务: {filename}")
-
- task_rec_dir = os.path.join(REC_FOLDER, filename)
- task_send_dir = os.path.join(SEND_FOLDER, filename)
-
- try:
- # 创建任务文件夹
- os.makedirs(task_rec_dir, exist_ok=True)
- os.makedirs(task_send_dir, exist_ok=True)
-
- # 保存并解压 zip 文件
- zip_path = os.path.join(REC_FOLDER, f"{filename}.zip")
- with open(zip_path, 'wb') as f:
- f.write(zip_data)
- print(f"[WebSocket] 已保存 zip 文件: {zip_path}")
-
- # 解压到 rec/{filename}/ 文件夹
- with zipfile.ZipFile(zip_path, 'r') as zf:
- zf.extractall(task_rec_dir)
- print(f"[WebSocket] 已解压到: {task_rec_dir}")
-
- # 删除原始 zip 文件
- try:
- os.remove(zip_path)
- except:
- pass
-
- # 检查图片文件
- supported_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']
- image_files = []
- for f in os.listdir(task_rec_dir):
- ext = os.path.splitext(f)[1].lower()
- if ext in supported_formats:
- image_files.append(f)
-
- if not image_files:
- print(f"[WebSocket] 未找到图片文件")
- return
-
- print(f"[WebSocket] 找到 {len(image_files)} 张图片,开始处理...")
-
- # 调用 birefnet-matting 处理
- if birefnet_matting:
- birefnet_matting.process_folder_birefnet_by_id(filename, REC_FOLDER, SEND_FOLDER)
- else:
- print("[WebSocket] birefnet_matting 模块未加载,跳过处理")
- # 如果模块未加载,直接复制文件作为测试
- for img in image_files:
- src = os.path.join(task_rec_dir, img)
- dst = os.path.join(task_send_dir, img)
- shutil.copy(src, dst)
-
- # 等待处理完成,检查输出文件
- output_zip_path = os.path.join(task_send_dir, f"{filename}.zip")
-
- # 如果 birefnet 没有生成 zip,我们自己打包
- if not os.path.exists(output_zip_path):
- print(f"[WebSocket] 正在打包处理结果...")
- # 检查 send 文件夹中的图片
- output_images = []
- for f in os.listdir(task_send_dir):
- ext = os.path.splitext(f)[1].lower()
- if ext in supported_formats:
- output_images.append(f)
-
- if output_images:
- with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
- for img in output_images:
- img_path = os.path.join(task_send_dir, img)
- zf.write(img_path, img)
- print(f"[WebSocket] 已创建输出 zip: {output_zip_path}")
-
- # 读取处理后的 zip 并发回
- if os.path.exists(output_zip_path):
- with open(output_zip_path, 'rb') as f:
- result_zip_data = f.read()
-
- # 发送格式:4字节文件名长度 + 文件名 + zip数据
- filename_bytes = filename.encode('utf-8')
- message = len(filename_bytes).to_bytes(4, 'big') + filename_bytes + result_zip_data
-
- await websocket.send(message)
- print(f"[WebSocket] 已发送处理结果,大小: {len(result_zip_data)} 字节")
- else:
- print(f"[WebSocket] 未找到输出文件: {output_zip_path}")
-
- # 清理临时文件
- print(f"[WebSocket] 清理临时文件...")
- if os.path.exists(task_rec_dir):
- shutil.rmtree(task_rec_dir)
- if os.path.exists(task_send_dir):
- shutil.rmtree(task_send_dir)
- print(f"[WebSocket] 任务 {filename} 处理完成")
-
- except Exception as e:
- import traceback
- print(f"[WebSocket] 处理任务失败: {e}")
- traceback.print_exc()
- # 清理
- if os.path.exists(task_rec_dir):
- try:
- shutil.rmtree(task_rec_dir)
- except:
- pass
- if os.path.exists(task_send_dir):
- try:
- shutil.rmtree(task_send_dir)
- except:
- pass
- async def websocket_client():
- """WebSocket 客户端,自动连接并在断线后每秒重连"""
- global ws_connection, ws_running
-
- while ws_running:
- try:
- print(f"正在连接到 WebSocket 服务器: {WS_SERVER_URL}")
- async with websockets.connect(WS_SERVER_URL, max_size=100 * 1024 * 1024) as websocket:
- ws_connection = websocket
- print(f"已成功连接到 WebSocket 服务器: {WS_SERVER_URL}")
-
- # 保持连接并处理消息
- while ws_running:
- try:
- # 接收消息(设置超时以便定期检查 ws_running 状态)
- message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
-
- # 检查是否是二进制数据(zip 包)
- if isinstance(message, bytes):
- print(f"[WebSocket] 收到二进制数据,大小: {len(message)} 字节")
- # 解析消息:前4字节是文件名长度,接着是文件名,然后是zip数据
- try:
- filename_len = int.from_bytes(message[:4], 'big')
- filename = message[4:4+filename_len].decode('utf-8')
- zip_data = message[4+filename_len:]
-
- print(f"[WebSocket] 收到任务: {filename}, zip大小: {len(zip_data)} 字节")
-
- # 异步处理任务
- asyncio.create_task(process_received_zip(filename, zip_data, websocket))
-
- except Exception as e:
- print(f"[WebSocket] 解析数据失败: {e}")
- else:
- print(f"收到 WebSocket 消息: {message}")
-
- except asyncio.TimeoutError:
- # 超时只是为了定期检查 ws_running 状态,继续循环
- continue
- except websockets.exceptions.ConnectionClosed:
- print("WebSocket 连接已关闭")
- break
-
- except (websockets.exceptions.ConnectionClosed,
- websockets.exceptions.InvalidStatusCode,
- ConnectionRefusedError,
- OSError) as e:
- ws_connection = None
- if ws_running:
- print(f"WebSocket 连接失败或断开: {e},1秒后重试...")
- await asyncio.sleep(1)
- except Exception as e:
- ws_connection = None
- if ws_running:
- print(f"WebSocket 发生错误: {e},1秒后重试...")
- await asyncio.sleep(1)
-
- print("WebSocket 客户端已停止")
- @asynccontextmanager
- async def lifespan(app: FastAPI):
- """FastAPI 生命周期管理"""
- global ws_running
-
- # 启动时创建 WebSocket 客户端任务
- ws_task = asyncio.create_task(websocket_client())
- print("WebSocket 客户端任务已启动")
-
- yield
-
- # 关闭时停止 WebSocket 客户端
- ws_running = False
- if ws_connection:
- await ws_connection.close()
- ws_task.cancel()
- try:
- await ws_task
- except asyncio.CancelledError:
- pass
- print("WebSocket 客户端任务已停止")
- app = FastAPI(title="BiRefNet Matting Server", version="1.0.0", lifespan=lifespan)
- @app.get("/")
- async def root():
- """根路径,返回欢迎信息"""
- return {"message": "欢迎使用 BiRefNet Matting Server", "status": "running"}
- @app.get("/health")
- async def health():
- """健康检查端点"""
- return {"status": "healthy"}
- @app.get("/api/info")
- async def info():
- """获取服务器信息"""
- return {
- "name": "BiRefNet Matting Server",
- "version": "1.0.0",
- "description": "图像抠图服务"
- }
- @app.post("/api/process")
- async def process_images(file: UploadFile = File(...)):
- """
- 接收zip图包,解压后使用BiRefNet进行抠图,然后打包返回
-
- Args:
- file: 上传的zip文件
-
- Returns:
- 处理后的zip文件
- """
- if birefnet_matting is None:
- raise HTTPException(status_code=500, detail="BiRefNet模块未正确加载")
-
- if not file.filename.endswith('.zip'):
- raise HTTPException(status_code=400, detail="只支持zip格式的文件")
-
- # 使用固定的rec和send文件夹
- script_dir = os.path.dirname(os.path.abspath(__file__))
- rec_dir = os.path.join(script_dir, "rec")
- send_dir = os.path.join(script_dir, "send")
- os.makedirs(rec_dir, exist_ok=True)
- os.makedirs(send_dir, exist_ok=True)
-
- zip_path = None
- output_zip_path = None
- response_sent = False
- unique_id = None
- task_rec_dir = None
- task_send_dir = None
-
- try:
- # 从zip文件名提取唯一ID(去掉.zip后缀)
- zip_filename = file.filename
- if not zip_filename.endswith('.zip'):
- raise HTTPException(status_code=400, detail="文件必须是zip格式")
-
- unique_id = zip_filename[:-4] # 去掉 .zip 后缀
- print(f"接收到任务ID: {unique_id}")
-
- # 创建对应的文件夹
- task_rec_dir = os.path.join(rec_dir, unique_id)
- task_send_dir = os.path.join(send_dir, unique_id)
- os.makedirs(task_rec_dir, exist_ok=True)
- os.makedirs(task_send_dir, exist_ok=True)
-
- # 保存上传的zip文件到rec文件夹
- zip_path = os.path.join(rec_dir, zip_filename)
- with open(zip_path, "wb") as f:
- shutil.copyfileobj(file.file, f)
-
- # 解压zip文件到rec/{ID}/文件夹
- print(f"正在解压文件到: {task_rec_dir}")
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
- zip_ref.extractall(task_rec_dir)
-
- # 删除上传的zip文件(已解压,不再需要)
- try:
- os.remove(zip_path)
- except:
- pass
-
- # 检查rec/{ID}/文件夹中是否有图片文件
- supported_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']
- rec_path = Path(task_rec_dir)
- image_files = []
- for ext in supported_formats:
- # 只使用rglob递归查找,避免重复计算
- image_files.extend(rec_path.rglob(f'*{ext}'))
- image_files.extend(rec_path.rglob(f'*{ext.upper()}'))
-
- # 去重(使用绝对路径去重,避免重复处理)
- unique_paths = set()
- unique_files = []
- for f in image_files:
- abs_path = str(f.resolve())
- if abs_path not in unique_paths:
- unique_paths.add(abs_path)
- unique_files.append(f)
- image_files = unique_files
-
- if not image_files:
- raise HTTPException(status_code=400, detail="zip文件中未找到支持的图片文件")
-
- # 记录实际接收到的图片数量
- received_image_count = len(image_files)
- print(f"接收到 {received_image_count} 张图片")
- print("图片列表:")
- for img in image_files:
- print(f" - {img.name}")
- print()
-
- print("开始处理图片...")
- # 调用birefnet-matting处理图片,传递ID
- birefnet_matting.process_folder_birefnet_by_id(unique_id, rec_dir, send_dir)
-
- # 等待处理完成,检查send/{ID}/文件夹中的zip文件
- output_zip_path = os.path.join(task_send_dir, f"{unique_id}.zip")
-
- # 等待zip文件生成(最多等待5分钟)
- import time
- max_wait_time = 300 # 5分钟
- wait_interval = 1 # 每秒检查一次
- waited_time = 0
-
- print("等待处理完成...")
- while not os.path.exists(output_zip_path) and waited_time < max_wait_time:
- time.sleep(wait_interval)
- waited_time += wait_interval
- if waited_time % 10 == 0:
- print(f" 已等待 {waited_time} 秒...")
-
- if not os.path.exists(output_zip_path):
- raise HTTPException(status_code=500, detail="处理超时,未生成输出文件")
-
- print(f"处理完成,找到结果文件: {output_zip_path}")
-
- # 读取zip文件内容并返回
- def generate() -> Generator[bytes, None, None]:
- try:
- with open(output_zip_path, 'rb') as f:
- while True:
- chunk = f.read(8192) # 8KB chunks
- if not chunk:
- break
- yield chunk
- finally:
- # 文件发送完成后清理临时文件
- print("正在清理临时文件...")
- # 清理 rec/{ID}/ 文件夹
- if os.path.exists(task_rec_dir):
- try:
- shutil.rmtree(task_rec_dir)
- print(f"已清理: {task_rec_dir}")
- except:
- pass
- # 清理 send/{ID}/ 文件夹
- if os.path.exists(task_send_dir):
- try:
- shutil.rmtree(task_send_dir)
- print(f"已清理: {task_send_dir}")
- except:
- pass
- print("临时文件清理完成")
-
- response_sent = True
- return StreamingResponse(
- generate(),
- media_type="application/zip",
- headers={"Content-Disposition": f"attachment; filename={unique_id}.zip"}
- )
-
- except zipfile.BadZipFile:
- # 清理临时文件
- if task_rec_dir and os.path.exists(task_rec_dir):
- try:
- shutil.rmtree(task_rec_dir)
- except:
- pass
- if task_send_dir and os.path.exists(task_send_dir):
- try:
- shutil.rmtree(task_send_dir)
- except:
- pass
- raise HTTPException(status_code=400, detail="无效的zip文件")
- except Exception as e:
- import traceback
- error_msg = f"处理过程中出错: {str(e)}\n{traceback.format_exc()}"
- print(error_msg)
- # 清理临时文件
- if task_rec_dir and os.path.exists(task_rec_dir):
- try:
- shutil.rmtree(task_rec_dir)
- except:
- pass
- if task_send_dir and os.path.exists(task_send_dir):
- try:
- shutil.rmtree(task_send_dir)
- except:
- pass
- raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
- if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=8000)
|