image-matting.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """
  2. 使用 rembg 库进行图像背景移除
  3. """
  4. import sys
  5. import os
  6. from rembg import remove
  7. from PIL import Image
  8. from glob import glob
  9. SUPPORTED_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']
  10. def extract_character_rembg(input_path, output_path):
  11. """
  12. 使用 rembg 库移除图像背景
  13. """
  14. try:
  15. print(f' -> 读取图片: {os.path.basename(input_path)}')
  16. if not os.path.exists(input_path):
  17. print(f' [X] 错误: 文件不存在')
  18. return False
  19. # 创建输出目录
  20. output_dir = os.path.dirname(output_path)
  21. if output_dir and not os.path.exists(output_dir):
  22. print(f' -> 创建输出目录: {output_dir}')
  23. os.makedirs(output_dir, exist_ok=True)
  24. # 读取输入图像
  25. print(f' -> 加载图像...')
  26. input_image = Image.open(input_path)
  27. original_size = input_image.size
  28. print(f' 图像尺寸: {original_size[0]}x{original_size[1]}')
  29. # 确保输入图像是RGB模式
  30. if input_image.mode != 'RGB':
  31. print(f' -> 转换颜色模式: {input_image.mode} -> RGB')
  32. input_image = input_image.convert('RGB')
  33. # 使用 rembg 移除背景
  34. print(f' -> 执行AI抠图...')
  35. output_image = remove(input_image, post_process_mask=False)
  36. print(f' [OK] 抠图完成')
  37. # 确保输出图像尺寸与输入图像一致
  38. if output_image.size != original_size:
  39. print(f' -> 调整图像尺寸')
  40. final_image = Image.new('RGBA', original_size, (0, 0, 0, 0))
  41. final_image.paste(output_image, (0, 0), output_image if output_image.mode == 'RGBA' else None)
  42. output_image = final_image
  43. else:
  44. if output_image.mode != 'RGBA':
  45. output_image = output_image.convert('RGBA')
  46. # 保存结果
  47. print(f' -> 保存结果: {os.path.basename(output_path)}')
  48. output_image.save(output_path, 'PNG')
  49. print(f' [OK] 保存成功')
  50. return True
  51. except Exception as error:
  52. print(f' [X] 处理失败: {error}')
  53. return False
  54. def process_folder(input_folder, output_folder):
  55. """
  56. 批量处理文件夹中的所有图片
  57. """
  58. print('=' * 60)
  59. print('[Step 1/2] AI Image Matting')
  60. print('=' * 60)
  61. if not os.path.exists(input_folder):
  62. print(f'[X] Error: Input folder not found: {input_folder}')
  63. return 0
  64. print(f'-> Input folder: {input_folder}')
  65. # 创建输出文件夹
  66. if not os.path.exists(output_folder):
  67. print(f'-> Creating output folder: {output_folder}')
  68. os.makedirs(output_folder, exist_ok=True)
  69. else:
  70. print(f'-> Output folder: {output_folder}')
  71. # 获取所有支持的图片文件
  72. print(f'-> Scanning image files...')
  73. image_files = []
  74. for ext in SUPPORTED_FORMATS:
  75. image_files.extend(glob(os.path.join(input_folder, f'*{ext}')))
  76. image_files.extend(glob(os.path.join(input_folder, f'*{ext.upper()}')))
  77. # 去重(Windows文件系统不区分大小写,可能重复)
  78. image_files = list(set(image_files))
  79. image_files.sort()
  80. if not image_files:
  81. print(f'[X] No supported image files found')
  82. return 0
  83. print(f'[OK] Found {len(image_files)} images')
  84. print('-' * 60)
  85. success_count = 0
  86. for idx, input_path in enumerate(image_files, 1):
  87. filename = os.path.basename(input_path)
  88. name_without_ext = os.path.splitext(filename)[0]
  89. output_path = os.path.join(output_folder, f'{name_without_ext}.png')
  90. print(f'\n[{idx}/{len(image_files)}] {filename}')
  91. print(f'PROGRESS: {idx}/{len(image_files)}')
  92. if extract_character_rembg(input_path, output_path):
  93. success_count += 1
  94. print(f' [OK] Done')
  95. print('\n' + '=' * 60)
  96. print(f'[Step 1 Complete] Success: {success_count}/{len(image_files)}')
  97. print('=' * 60)
  98. return success_count
  99. if __name__ == '__main__':
  100. if len(sys.argv) != 3:
  101. print('用法: python image-matting.py <输入文件夹> <输出文件夹>')
  102. sys.exit(1)
  103. input_folder = sys.argv[1]
  104. output_folder = sys.argv[2]
  105. processed = process_folder(input_folder, output_folder)
  106. sys.exit(0 if processed > 0 else 1)