| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- """
- BiRefNet 图像背景移除工具
- 用于高质量抠图处理
- """
- import sys
- import os
- # 添加 BiRefNet 目录到路径
- birefnet_dir = os.path.join(os.path.dirname(__file__), 'BiRefNet')
- sys.path.insert(0, birefnet_dir)
- try:
- from transformers import AutoModelForImageSegmentation
- from PIL import Image
- import torch
- import numpy as np
- from torchvision import transforms
- except ImportError as e:
- print(f"错误: 缺少必要的库 - {e}", file=sys.stderr)
- print("请安装: pip install transformers pillow torch numpy torchvision", file=sys.stderr)
- sys.exit(1)
- def process_image(input_path, output_path):
- """
- 处理单个图片文件
- """
- try:
- # 检查输入文件是否存在
- if not os.path.exists(input_path):
- print(f"错误: 输入文件不存在: {input_path}", file=sys.stderr)
- sys.exit(1)
-
- # 创建输出目录(如果不存在)
- output_dir = os.path.dirname(output_path)
- if output_dir and not os.path.exists(output_dir):
- os.makedirs(output_dir, exist_ok=True)
-
- print(f"正在加载 BiRefNet 模型...", flush=True)
-
- # 检查是否有本地模型路径
- local_model_path = os.path.join(birefnet_dir, 'model', 'BiRefNet')
- model_file = os.path.join(local_model_path, 'model.safetensors')
- config_file = os.path.join(local_model_path, 'birefnet.py')
- use_local = os.path.exists(model_file) and os.path.exists(config_file)
-
- # 加载 BiRefNet 模型
- try:
- if use_local:
- print(f"使用本地模型: {local_model_path}", flush=True)
- model = AutoModelForImageSegmentation.from_pretrained(
- local_model_path,
- trust_remote_code=True,
- local_files_only=True
- )
- else:
- print(f"从 HuggingFace 下载模型...", flush=True)
- print("提示: 如果下载失败,请运行 download-birefnet-model.py 下载模型到本地", flush=True)
- model = AutoModelForImageSegmentation.from_pretrained(
- 'ZhengPeng7/BiRefNet',
- trust_remote_code=True
- )
- except Exception as e:
- print(f"错误: 无法加载 BiRefNet 模型 - {e}", file=sys.stderr)
- if not use_local:
- print("提示: 请确保已安装 transformers 库并可以访问 HuggingFace", file=sys.stderr)
- print("或者运行以下命令下载模型到本地:", file=sys.stderr)
- print(" python download-birefnet-model.py", file=sys.stderr)
- sys.exit(1)
-
- # 设置设备
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- model = model.to(device)
- model.eval()
-
- half_precision = device == 'cuda'
- if half_precision:
- model = model.half() # 使用 FP16 加速
-
- print(f"正在处理图片: {os.path.basename(input_path)}", flush=True)
-
- # 读取输入图像
- image_ori = Image.open(input_path).convert('RGB')
- original_size = image_ori.size
-
- # 预处理图像(使用 BiRefNet 的预处理方式)
- try:
- from image_proc import ImagePreprocessor
- resolution = (1024, 1024) # 默认分辨率
- image_preprocessor = ImagePreprocessor(resolution=resolution)
- image_proc = image_preprocessor.proc(image_ori)
- image_proc = image_proc.unsqueeze(0)
- except ImportError:
- # 如果无法导入,使用简单的预处理
- transform = transforms.Compose([
- transforms.Resize((1024, 1024)),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
- image_proc = transform(image_ori).unsqueeze(0)
-
- image_proc = image_proc.to(device)
- if half_precision:
- image_proc = image_proc.half()
-
- # 推理
- print("正在进行VIP抠图处理...", flush=True)
- with torch.no_grad():
- preds = model(image_proc)[-1].sigmoid().cpu()
-
- # 后处理:将预测结果转换为图像
- pred = preds[0].squeeze()
- pred_pil = transforms.ToPILImage()(pred)
-
- # 调整到原始尺寸
- pred_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)
-
- # 应用遮罩到原图
- image_rgba = image_ori.convert('RGBA')
- mask_array = np.array(pred_pil)
-
- # 创建 RGBA 图像
- rgba_array = np.array(image_rgba)
- rgba_array[:, :, 3] = mask_array
-
- # 保存结果
- output_image = Image.fromarray(rgba_array, 'RGBA')
- output_image.save(output_path, 'PNG')
-
- print("VIP抠图完成", flush=True)
-
- except Exception as error:
- print(f"错误: {error}", file=sys.stderr)
- import traceback
- traceback.print_exc()
- sys.exit(1)
- def main():
- if len(sys.argv) < 3:
- print("错误: 需要提供输入和输出路径", file=sys.stderr)
- print("用法: python birefnet-matting.py <输入文件> <输出文件>", file=sys.stderr)
- sys.exit(1)
-
- input_path = sys.argv[1]
- output_path = sys.argv[2]
-
- process_image(input_path, output_path)
- if __name__ == '__main__':
- main()
|