birefnet-matting.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. """
  2. BiRefNet 图像背景移除工具
  3. 用于高质量抠图处理
  4. """
  5. import sys
  6. import os
  7. # 添加 BiRefNet 目录到路径
  8. birefnet_dir = os.path.join(os.path.dirname(__file__), 'BiRefNet')
  9. sys.path.insert(0, birefnet_dir)
  10. try:
  11. from transformers import AutoModelForImageSegmentation
  12. from PIL import Image
  13. import torch
  14. import numpy as np
  15. from torchvision import transforms
  16. except ImportError as e:
  17. print(f"错误: 缺少必要的库 - {e}", file=sys.stderr)
  18. print("请安装: pip install transformers pillow torch numpy torchvision", file=sys.stderr)
  19. sys.exit(1)
  20. def process_image(input_path, output_path):
  21. """
  22. 处理单个图片文件
  23. """
  24. try:
  25. # 检查输入文件是否存在
  26. if not os.path.exists(input_path):
  27. print(f"错误: 输入文件不存在: {input_path}", file=sys.stderr)
  28. sys.exit(1)
  29. # 创建输出目录(如果不存在)
  30. output_dir = os.path.dirname(output_path)
  31. if output_dir and not os.path.exists(output_dir):
  32. os.makedirs(output_dir, exist_ok=True)
  33. print(f"正在加载 BiRefNet 模型...", flush=True)
  34. # 检查是否有本地模型路径
  35. local_model_path = os.path.join(birefnet_dir, 'model', 'BiRefNet')
  36. model_file = os.path.join(local_model_path, 'model.safetensors')
  37. config_file = os.path.join(local_model_path, 'birefnet.py')
  38. use_local = os.path.exists(model_file) and os.path.exists(config_file)
  39. # 加载 BiRefNet 模型
  40. try:
  41. if use_local:
  42. print(f"使用本地模型: {local_model_path}", flush=True)
  43. model = AutoModelForImageSegmentation.from_pretrained(
  44. local_model_path,
  45. trust_remote_code=True,
  46. local_files_only=True
  47. )
  48. else:
  49. print(f"从 HuggingFace 下载模型...", flush=True)
  50. print("提示: 如果下载失败,请运行 download-birefnet-model.py 下载模型到本地", flush=True)
  51. model = AutoModelForImageSegmentation.from_pretrained(
  52. 'ZhengPeng7/BiRefNet',
  53. trust_remote_code=True
  54. )
  55. except Exception as e:
  56. print(f"错误: 无法加载 BiRefNet 模型 - {e}", file=sys.stderr)
  57. if not use_local:
  58. print("提示: 请确保已安装 transformers 库并可以访问 HuggingFace", file=sys.stderr)
  59. print("或者运行以下命令下载模型到本地:", file=sys.stderr)
  60. print(" python download-birefnet-model.py", file=sys.stderr)
  61. sys.exit(1)
  62. # 设置设备
  63. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  64. model = model.to(device)
  65. model.eval()
  66. half_precision = device == 'cuda'
  67. if half_precision:
  68. model = model.half() # 使用 FP16 加速
  69. print(f"正在处理图片: {os.path.basename(input_path)}", flush=True)
  70. # 读取输入图像
  71. image_ori = Image.open(input_path).convert('RGB')
  72. original_size = image_ori.size
  73. # 预处理图像(使用 BiRefNet 的预处理方式)
  74. try:
  75. from image_proc import ImagePreprocessor
  76. resolution = (1024, 1024) # 默认分辨率
  77. image_preprocessor = ImagePreprocessor(resolution=resolution)
  78. image_proc = image_preprocessor.proc(image_ori)
  79. image_proc = image_proc.unsqueeze(0)
  80. except ImportError:
  81. # 如果无法导入,使用简单的预处理
  82. transform = transforms.Compose([
  83. transforms.Resize((1024, 1024)),
  84. transforms.ToTensor(),
  85. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  86. ])
  87. image_proc = transform(image_ori).unsqueeze(0)
  88. image_proc = image_proc.to(device)
  89. if half_precision:
  90. image_proc = image_proc.half()
  91. # 推理
  92. print("正在进行VIP抠图处理...", flush=True)
  93. with torch.no_grad():
  94. preds = model(image_proc)[-1].sigmoid().cpu()
  95. # 后处理:将预测结果转换为图像
  96. pred = preds[0].squeeze()
  97. pred_pil = transforms.ToPILImage()(pred)
  98. # 调整到原始尺寸
  99. pred_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)
  100. # 应用遮罩到原图
  101. image_rgba = image_ori.convert('RGBA')
  102. mask_array = np.array(pred_pil)
  103. # 创建 RGBA 图像
  104. rgba_array = np.array(image_rgba)
  105. rgba_array[:, :, 3] = mask_array
  106. # 保存结果
  107. output_image = Image.fromarray(rgba_array, 'RGBA')
  108. output_image.save(output_path, 'PNG')
  109. print("VIP抠图完成", flush=True)
  110. except Exception as error:
  111. print(f"错误: {error}", file=sys.stderr)
  112. import traceback
  113. traceback.print_exc()
  114. sys.exit(1)
  115. def main():
  116. if len(sys.argv) < 3:
  117. print("错误: 需要提供输入和输出路径", file=sys.stderr)
  118. print("用法: python birefnet-matting.py <输入文件> <输出文件>", file=sys.stderr)
  119. sys.exit(1)
  120. input_path = sys.argv[1]
  121. output_path = sys.argv[2]
  122. process_image(input_path, output_path)
  123. if __name__ == '__main__':
  124. main()