"""BiRefNet image background removal tool.

Command-line script that cuts the foreground out of a single image with a
BiRefNet segmentation model and writes the result as an RGBA PNG whose alpha
channel is the predicted mask.

Usage:
    python birefnet-matting.py <input file> <output file>
"""
import os
import sys
import traceback

# Put the bundled BiRefNet directory first on the import path. The optional
# ``image_proc`` import in ``_preprocess`` presumably resolves from there.
birefnet_dir = os.path.join(os.path.dirname(__file__), 'BiRefNet')
sys.path.insert(0, birefnet_dir)

try:
    from transformers import AutoModelForImageSegmentation
    from PIL import Image
    import torch
    import numpy as np
    from torchvision import transforms
except ImportError as e:
    print(f"错误: 缺少必要的库 - {e}", file=sys.stderr)
    print("请安装: pip install transformers pillow torch numpy torchvision", file=sys.stderr)
    sys.exit(1)

# Input size fed to the network by both preprocessing paths.
_INPUT_RESOLUTION = (1024, 1024)


def _load_model():
    """Load the BiRefNet model and move it to the best available device.

    A local copy under ``<birefnet_dir>/model/BiRefNet`` is used only when
    both ``model.safetensors`` and ``birefnet.py`` exist there; otherwise the
    ``ZhengPeng7/BiRefNet`` checkpoint is requested from HuggingFace.

    Returns:
        A ``(model, device, half_precision)`` tuple. ``device`` is ``'cuda'``
        or ``'cpu'``; ``half_precision`` is True exactly when running on
        CUDA, in which case the model has been converted to FP16.

    Exits the process with status 1 if the model cannot be loaded.
    """
    print("正在加载 BiRefNet 模型...", flush=True)

    local_model_path = os.path.join(birefnet_dir, 'model', 'BiRefNet')
    model_file = os.path.join(local_model_path, 'model.safetensors')
    config_file = os.path.join(local_model_path, 'birefnet.py')
    use_local = os.path.exists(model_file) and os.path.exists(config_file)

    try:
        if use_local:
            print(f"使用本地模型: {local_model_path}", flush=True)
            model = AutoModelForImageSegmentation.from_pretrained(
                local_model_path,
                trust_remote_code=True,
                local_files_only=True,
            )
        else:
            print("从 HuggingFace 下载模型...", flush=True)
            print("提示: 如果下载失败,请运行 download-birefnet-model.py 下载模型到本地", flush=True)
            model = AutoModelForImageSegmentation.from_pretrained(
                'ZhengPeng7/BiRefNet',
                trust_remote_code=True,
            )
    except Exception as e:
        print(f"错误: 无法加载 BiRefNet 模型 - {e}", file=sys.stderr)
        if not use_local:
            # Download hints only make sense when the remote path was tried.
            print("提示: 请确保已安装 transformers 库并可以访问 HuggingFace", file=sys.stderr)
            print("或者运行以下命令下载模型到本地:", file=sys.stderr)
            print(" python download-birefnet-model.py", file=sys.stderr)
        sys.exit(1)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()
    half_precision = device == 'cuda'
    if half_precision:
        model = model.half()  # FP16 to speed up inference on CUDA
    return model, device, half_precision


def _preprocess(image):
    """Turn an RGB PIL image into a batched input tensor for the model.

    Uses ``ImagePreprocessor`` from ``image_proc`` when that module can be
    imported; otherwise falls back to a plain resize / to-tensor / normalize
    pipeline. Both paths add a leading batch dimension.
    """
    try:
        from image_proc import ImagePreprocessor
    except ImportError:
        # Only a failed import selects the fallback; errors raised while
        # actually preprocessing are left to propagate to the caller.
        fallback = transforms.Compose([
            transforms.Resize(_INPUT_RESOLUTION),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        return fallback(image).unsqueeze(0)

    preprocessor = ImagePreprocessor(resolution=_INPUT_RESOLUTION)
    return preprocessor.proc(image).unsqueeze(0)


def process_image(input_path, output_path):
    """Remove the background from one image and save it as an RGBA PNG.

    Args:
        input_path: Path of the image file to read.
        output_path: Destination path. Missing parent directories are
            created, and the file is always encoded as PNG.

    Progress messages go to stdout. Any failure is reported on stderr and
    terminates the process with exit status 1; nothing is returned.
    """
    try:
        if not os.path.exists(input_path):
            print(f"错误: 输入文件不存在: {input_path}", file=sys.stderr)
            sys.exit(1)

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        model, device, half_precision = _load_model()

        print(f"正在处理图片: {os.path.basename(input_path)}", flush=True)

        # convert() returns a detached copy, so the underlying file handle
        # can be released as soon as the pixels have been read.
        with Image.open(input_path) as source:
            image_ori = source.convert('RGB')
        original_size = image_ori.size

        image_proc = _preprocess(image_ori).to(device)
        if half_precision:
            image_proc = image_proc.half()

        print("正在进行VIP抠图处理...", flush=True)
        with torch.no_grad():
            # Cast to float32 before leaving the device so the CPU-side
            # conversion to an 8-bit mask never operates on FP16 data.
            preds = model(image_proc)[-1].sigmoid().float().cpu()

        # First (only) item of the batch -> PIL mask at the original size.
        pred_pil = transforms.ToPILImage()(preds[0].squeeze())
        pred_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)

        # Use the predicted mask as the alpha channel of the original image.
        rgba_array = np.array(image_ori.convert('RGBA'))
        rgba_array[:, :, 3] = np.array(pred_pil)

        output_image = Image.fromarray(rgba_array, 'RGBA')
        output_image.save(output_path, 'PNG')

        print("VIP抠图完成", flush=True)
    except Exception as error:
        # Top-level boundary for the CLI: report, show the traceback, and
        # signal failure through the exit status.
        print(f"错误: {error}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)


def main():
    """Parse ``<input> <output>`` from the command line and run the matting."""
    if len(sys.argv) < 3:
        print("错误: 需要提供输入和输出路径", file=sys.stderr)
        print("用法: python birefnet-matting.py <输入文件> <输出文件>", file=sys.stderr)
        sys.exit(1)

    input_path = sys.argv[1]
    output_path = sys.argv[2]
    process_image(input_path, output_path)


if __name__ == '__main__':
    main()