birefnet-matting.py 20 KB

"""
Remove image backgrounds with BiRefNet.

Uses the BiRefNet environment that already exists in this project, which
gives the best results.

Usage:
    python ImageMatting_birefnet.py                       # batch-process the InputImage folder
    python ImageMatting_birefnet.py input.png output.png  # process a single file
"""
  8. import os
  9. import sys
  10. import torch
  11. from PIL import Image
  12. from torchvision import transforms
  13. from glob import glob
# Image file extensions (lower-case, with the leading dot) that the batch
# helpers in this file search for.
SUPPORTED_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']

# Put the "BiRefNet" directory that sits next to this script at the front of
# sys.path so that the imports below (models, image_proc, utils) resolve
# against it.  This must happen before the try/except import block.
script_dir = os.path.dirname(os.path.abspath(__file__))
birefnet_dir = os.path.join(script_dir, 'BiRefNet')
sys.path.insert(0, birefnet_dir)

# Try to import the local BiRefNet modules; fall back to HuggingFace
# transformers when that fails.  The names are pre-bound to None so they
# always exist at module level, even if the local import does not succeed.
BiRefNet = None
refine_foreground = None
check_state_dict = None
USE_LOCAL_BIREFNET = False
try:
    from models.birefnet import BiRefNet
    from image_proc import refine_foreground
    from utils import check_state_dict
    # Only set once all three imports succeeded.
    USE_LOCAL_BIREFNET = True
    print('已加载本地 BiRefNet 模块')
except ImportError as e:
    print(f'警告: 无法导入本地 BiRefNet 模块: {e}')
    print('将使用 HuggingFace transformers 方式加载模型')
    try:
        from transformers import AutoModelForImageSegmentation
        print('已导入 transformers.AutoModelForImageSegmentation')
    except ImportError:
        # Neither loading route is available: abort at import time.
        print('错误: 请安装 transformers 库: pip install transformers')
        sys.exit(1)
  40. def extract_character_birefnet(input_path, output_path, model_path=None, device=None):
  41. """
  42. 使用 BiRefNet 移除图像背景
  43. Args:
  44. input_path: 输入图像路径
  45. output_path: 输出图像路径(PNG格式,背景透明)
  46. model_path: 模型权重路径(如果为None,会尝试从 HuggingFace 或本地加载)
  47. device: 设备 ('cuda' 或 'cpu'),如果为None则自动检测
  48. """
  49. try:
  50. script_dir = os.path.dirname(os.path.abspath(__file__))
  51. # 检查输入文件是否存在
  52. if not os.path.exists(input_path):
  53. print(f'错误: {input_path} 文件不存在')
  54. return False
  55. # 创建输出目录
  56. output_dir = os.path.dirname(output_path)
  57. if output_dir and not os.path.exists(output_dir):
  58. os.makedirs(output_dir, exist_ok=True)
  59. # 检测设备
  60. if device is None:
  61. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  62. if device == 'cuda' and not torch.cuda.is_available():
  63. device = 'cpu'
  64. # 设置 PyTorch 精度
  65. torch.set_float32_matmul_precision(['high', 'highest'][0])
  66. # 加载模型 - 优先使用本地模型,然后尝试 HuggingFace
  67. birefnet = None
  68. # 方法1: 尝试从本地 safetensors 文件加载
  69. local_model_path = os.path.join(script_dir, 'BiRefNet', 'model', 'BiRefNet', 'model.safetensors')
  70. if model_path is None and os.path.exists(local_model_path):
  71. try:
  72. print(f'尝试从本地加载模型: {local_model_path}')
  73. # 注意: safetensors 需要特殊处理,这里尝试使用 HuggingFace 方式
  74. from transformers import AutoModelForImageSegmentation
  75. birefnet = AutoModelForImageSegmentation.from_pretrained(
  76. os.path.join(script_dir, 'BiRefNet', 'model', 'BiRefNet'),
  77. trust_remote_code=True
  78. )
  79. print('成功从本地加载模型')
  80. except Exception as e:
  81. print(f'从本地加载失败: {e},尝试其他方法...')
  82. # 方法2: 如果指定了 model_path,尝试加载 .pth 文件
  83. if birefnet is None and model_path and os.path.exists(model_path):
  84. try:
  85. print(f'从 {model_path} 加载模型权重...')
  86. birefnet = BiRefNet(bb_pretrained=False)
  87. state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
  88. state_dict = check_state_dict(state_dict)
  89. birefnet.load_state_dict(state_dict)
  90. print('成功加载 .pth 模型权重')
  91. except Exception as e:
  92. print(f'加载 .pth 文件失败: {e}')
  93. # 方法3: 尝试使用 BiRefNet.from_pretrained() 从 HuggingFace 加载
  94. if birefnet is None:
  95. try:
  96. print('尝试从 HuggingFace 加载模型...')
  97. birefnet = BiRefNet.from_pretrained('zhengpeng7/BiRefNet')
  98. print('成功从 HuggingFace 加载模型')
  99. except Exception as e:
  100. print(f'从 HuggingFace 加载失败: {e}')
  101. # 方法4: 使用 transformers 库
  102. try:
  103. from transformers import AutoModelForImageSegmentation
  104. birefnet = AutoModelForImageSegmentation.from_pretrained(
  105. 'ZhengPeng7/BiRefNet',
  106. trust_remote_code=True
  107. )
  108. print('成功通过 transformers 从 HuggingFace 加载模型')
  109. except Exception as e2:
  110. print(f'所有加载方法都失败了: {e2}')
  111. return False
  112. # 设置模型为评估模式
  113. birefnet.to(device)
  114. birefnet.eval()
  115. # 使用混合精度(如果支持)
  116. use_half = device == 'cuda'
  117. if use_half:
  118. try:
  119. birefnet.half()
  120. print('已启用 FP16 精度')
  121. except:
  122. print('FP16 不可用,使用 FP32')
  123. use_half = False
  124. print('正在移除背景...')
  125. # 图像预处理
  126. transform_image = transforms.Compose([
  127. transforms.Resize((1024, 1024)),
  128. transforms.ToTensor(),
  129. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  130. ])
  131. # 读取图像
  132. image = Image.open(input_path)
  133. original_size = image.size
  134. image = image.convert("RGB") if image.mode != "RGB" else image
  135. # 预处理
  136. input_tensor = transform_image(image).unsqueeze(0)
  137. if use_half:
  138. input_tensor = input_tensor.half()
  139. input_tensor = input_tensor.to(device)
  140. # 推理
  141. autocast_ctx = torch.amp.autocast(device_type=device, dtype=torch.float16 if use_half else None) if device == 'cuda' else torch.no_grad()
  142. with autocast_ctx, torch.no_grad():
  143. preds = birefnet(input_tensor)[-1].sigmoid().to(torch.float32).cpu()
  144. pred = preds[0].squeeze()
  145. # 将预测结果转换为 PIL Image 并调整回原始尺寸
  146. pred_pil = transforms.ToPILImage()(pred)
  147. pred_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)
  148. # 使用 refine_foreground 优化前景
  149. try:
  150. image_masked = refine_foreground(image, pred_pil, device=device)
  151. image_masked.putalpha(pred_pil)
  152. except Exception as e:
  153. print(f'前景优化失败,使用简单方法: {e}')
  154. # 如果 refine_foreground 失败,使用简单方法
  155. image_masked = image.copy()
  156. image_masked.putalpha(pred_pil)
  157. # 保存结果
  158. image_masked.save(output_path, 'PNG')
  159. return True
  160. except Exception as error:
  161. print(f'处理图像时出错: {error}')
  162. import traceback
  163. traceback.print_exc()
  164. return False
  165. def load_birefnet_model(model_path=None, device=None):
  166. """
  167. 加载 BiRefNet 模型(只加载一次,用于批量处理)
  168. Returns:
  169. (model, device, use_half): 模型、设备、是否使用半精度
  170. """
  171. # 检测设备
  172. if device is None:
  173. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  174. if device == 'cuda' and not torch.cuda.is_available():
  175. device = 'cpu'
  176. use_half = device == 'cuda'
  177. script_dir = os.path.dirname(os.path.abspath(__file__))
  178. # 加载模型
  179. birefnet = None
  180. # 方法1: 尝试从本地 safetensors 文件加载
  181. local_model_path = os.path.join(script_dir, 'BiRefNet', 'model', 'BiRefNet', 'model.safetensors')
  182. if model_path is None and os.path.exists(local_model_path):
  183. try:
  184. from transformers import AutoModelForImageSegmentation
  185. birefnet = AutoModelForImageSegmentation.from_pretrained(
  186. os.path.join(script_dir, 'BiRefNet', 'model', 'BiRefNet'),
  187. trust_remote_code=True
  188. )
  189. except Exception as e:
  190. pass
  191. # 方法2: 如果指定了 model_path,尝试加载 .pth 文件
  192. if birefnet is None and model_path and os.path.exists(model_path):
  193. try:
  194. birefnet = BiRefNet(bb_pretrained=False)
  195. state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
  196. state_dict = check_state_dict(state_dict)
  197. birefnet.load_state_dict(state_dict)
  198. except Exception as e:
  199. pass
  200. # 方法3: 尝试使用 BiRefNet.from_pretrained() 从 HuggingFace 加载
  201. if birefnet is None:
  202. try:
  203. birefnet = BiRefNet.from_pretrained('zhengpeng7/BiRefNet')
  204. except Exception as e:
  205. # 方法4: 使用 transformers 库
  206. try:
  207. from transformers import AutoModelForImageSegmentation
  208. birefnet = AutoModelForImageSegmentation.from_pretrained(
  209. 'ZhengPeng7/BiRefNet',
  210. trust_remote_code=True
  211. )
  212. except Exception as e2:
  213. raise Exception(f'所有加载方法都失败了: {e2}')
  214. # 设置模型为评估模式
  215. birefnet.to(device)
  216. birefnet.eval()
  217. # 使用混合精度(如果支持)
  218. if use_half:
  219. try:
  220. birefnet.half()
  221. except:
  222. use_half = False
  223. return birefnet, device, use_half
  224. def extract_character_birefnet_with_model(input_path, output_path, birefnet, device, use_half):
  225. """
  226. 使用已加载的模型处理单张图片(用于批量处理)
  227. """
  228. try:
  229. # 检查输入文件是否存在
  230. if not os.path.exists(input_path):
  231. return False
  232. # 创建输出目录
  233. output_dir = os.path.dirname(output_path)
  234. if output_dir and not os.path.exists(output_dir):
  235. os.makedirs(output_dir, exist_ok=True)
  236. # 图像预处理
  237. transform_image = transforms.Compose([
  238. transforms.Resize((1024, 1024)),
  239. transforms.ToTensor(),
  240. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  241. ])
  242. # 读取图像
  243. image = Image.open(input_path)
  244. original_size = image.size
  245. image = image.convert("RGB") if image.mode != "RGB" else image
  246. # 预处理
  247. input_tensor = transform_image(image).unsqueeze(0)
  248. if use_half:
  249. input_tensor = input_tensor.half()
  250. input_tensor = input_tensor.to(device)
  251. # 推理
  252. autocast_ctx = torch.amp.autocast(device_type=device, dtype=torch.float16 if use_half else None) if device == 'cuda' else torch.no_grad()
  253. with autocast_ctx, torch.no_grad():
  254. preds = birefnet(input_tensor)[-1].sigmoid().to(torch.float32).cpu()
  255. pred = preds[0].squeeze()
  256. # 将预测结果转换为 PIL Image 并调整回原始尺寸
  257. pred_pil = transforms.ToPILImage()(pred)
  258. pred_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)
  259. # 使用 refine_foreground 优化前景
  260. try:
  261. image_masked = refine_foreground(image, pred_pil, device=device)
  262. image_masked.putalpha(pred_pil)
  263. except Exception as e:
  264. # 如果 refine_foreground 失败,使用简单方法
  265. image_masked = image.copy()
  266. image_masked.putalpha(pred_pil)
  267. # 保存结果
  268. image_masked.save(output_path, 'PNG')
  269. return True
  270. except Exception as error:
  271. return False
  272. def process_folder_birefnet(input_folder, output_folder, model_path=None, device=None):
  273. """
  274. 批量处理文件夹中的所有图片
  275. Args:
  276. input_folder: 输入文件夹路径
  277. output_folder: 输出文件夹路径
  278. model_path: 模型权重路径
  279. device: 设备
  280. """
  281. if not os.path.exists(input_folder):
  282. print(f'错误: 输入文件夹不存在: {input_folder}')
  283. return
  284. # 创建输出文件夹
  285. if not os.path.exists(output_folder):
  286. os.makedirs(output_folder, exist_ok=True)
  287. print(f'已创建输出文件夹: {output_folder}')
  288. # 获取所有支持的图片文件(递归查找)
  289. from pathlib import Path
  290. input_path = Path(input_folder)
  291. image_files = []
  292. for ext in SUPPORTED_FORMATS:
  293. # 使用递归查找,避免遗漏子文件夹中的图片
  294. image_files.extend(input_path.rglob(f'*{ext}'))
  295. image_files.extend(input_path.rglob(f'*{ext.upper()}'))
  296. # 去重(使用绝对路径去重,避免重复处理)
  297. # 先转换为绝对路径的字符串集合去重,再转回Path对象
  298. unique_paths = set()
  299. unique_files = []
  300. for f in image_files:
  301. abs_path = str(f.resolve())
  302. if abs_path not in unique_paths:
  303. unique_paths.add(abs_path)
  304. unique_files.append(f)
  305. image_files = unique_files
  306. if not image_files:
  307. print(f'在 {input_folder} 中未找到支持的图片文件')
  308. return
  309. print(f'找到 {len(image_files)} 张图片')
  310. print('正在加载模型...')
  311. # 加载模型(只加载一次,提高效率)
  312. try:
  313. birefnet, device_actual, use_half = load_birefnet_model(model_path, device)
  314. print(f'模型加载成功,使用设备: {device_actual}')
  315. if use_half:
  316. print('已启用 FP16 精度')
  317. except Exception as e:
  318. print(f'模型加载失败: {e}')
  319. return
  320. print('开始处理...')
  321. print('=' * 50)
  322. success_count = 0
  323. for idx, input_path in enumerate(image_files, 1):
  324. filename = os.path.basename(input_path)
  325. name_without_ext = os.path.splitext(filename)[0]
  326. output_path = os.path.join(output_folder, f'{name_without_ext}.png')
  327. print(f'[{idx}/{len(image_files)}] 正在处理: {filename}')
  328. if extract_character_birefnet_with_model(input_path, output_path, birefnet, device_actual, use_half):
  329. print(f' ✓ 成功保存到: {os.path.basename(output_path)}')
  330. success_count += 1
  331. else:
  332. print(f' ✗ 处理失败')
  333. print()
  334. print('=' * 50)
  335. print(f'处理完成!成功: {success_count}/{len(image_files)}')
  336. def process_folder_birefnet_by_id(unique_id, rec_base_dir, send_base_dir):
  337. """
  338. 根据唯一ID处理指定文件夹中的图片
  339. Args:
  340. unique_id: 唯一任务ID
  341. rec_base_dir: rec文件夹的基础路径
  342. send_base_dir: send文件夹的基础路径
  343. """
  344. input_folder = os.path.join(rec_base_dir, unique_id)
  345. output_folder = os.path.join(send_base_dir, unique_id)
  346. if not os.path.exists(input_folder):
  347. print(f'错误: 输入文件夹不存在: {input_folder}')
  348. return
  349. # 创建输出文件夹
  350. if not os.path.exists(output_folder):
  351. os.makedirs(output_folder, exist_ok=True)
  352. print(f'已创建输出文件夹: {output_folder}')
  353. # 获取所有支持的图片文件
  354. from pathlib import Path
  355. input_path = Path(input_folder)
  356. image_files = []
  357. for ext in SUPPORTED_FORMATS:
  358. image_files.extend(input_path.rglob(f'*{ext}'))
  359. image_files.extend(input_path.rglob(f'*{ext.upper()}'))
  360. # 去重(使用绝对路径去重,避免重复处理)
  361. unique_paths = set()
  362. unique_files = []
  363. for f in image_files:
  364. abs_path = str(f.resolve())
  365. if abs_path not in unique_paths:
  366. unique_paths.add(abs_path)
  367. unique_files.append(f)
  368. image_files = unique_files
  369. if not image_files:
  370. print(f'在 {input_folder} 中未找到支持的图片文件')
  371. return
  372. print(f'找到 {len(image_files)} 张图片')
  373. print('正在加载模型...')
  374. # 加载模型(只加载一次,提高效率)
  375. try:
  376. birefnet, device_actual, use_half = load_birefnet_model(None, None)
  377. print(f'模型加载成功,使用设备: {device_actual}')
  378. if use_half:
  379. print('已启用 FP16 精度')
  380. except Exception as e:
  381. print(f'模型加载失败: {e}')
  382. return
  383. print('开始处理...')
  384. print('=' * 50)
  385. success_count = 0
  386. for idx, input_path in enumerate(image_files, 1):
  387. filename = os.path.basename(input_path)
  388. name_without_ext = os.path.splitext(filename)[0]
  389. output_path = os.path.join(output_folder, f'{name_without_ext}.png')
  390. print(f'[{idx}/{len(image_files)}] 正在处理: {filename}')
  391. if extract_character_birefnet_with_model(input_path, output_path, birefnet, device_actual, use_half):
  392. print(f' ✓ 成功保存到: {os.path.basename(output_path)}')
  393. success_count += 1
  394. else:
  395. print(f' ✗ 处理失败')
  396. print()
  397. print('=' * 50)
  398. print(f'处理完成!成功: {success_count}/{len(image_files)}')
  399. # 打包处理后的图片为zip文件
  400. output_zip_path = os.path.join(output_folder, f'{unique_id}.zip')
  401. print(f'正在打包到: {output_zip_path}')
  402. processed_images = []
  403. for ext in SUPPORTED_FORMATS:
  404. processed_images.extend(Path(output_folder).glob(f'*{ext}'))
  405. processed_images.extend(Path(output_folder).glob(f'*{ext.upper()}'))
  406. # 排除zip文件本身
  407. processed_images = [f for f in processed_images if not f.name.lower().endswith('.zip')]
  408. if processed_images:
  409. import zipfile
  410. with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
  411. for img_path in processed_images:
  412. arcname = img_path.name
  413. zipf.write(str(img_path), arcname)
  414. print(f'打包完成: {output_zip_path}')
  415. else:
  416. print('警告: 没有找到处理后的图片文件')
  417. if __name__ == '__main__':
  418. import argparse
  419. parser = argparse.ArgumentParser(description='使用 BiRefNet 移除图像背景')
  420. parser.add_argument('input', nargs='?', help='输入图像路径或输入文件夹路径')
  421. parser.add_argument('output', nargs='?', help='输出图像路径或输出文件夹路径')
  422. parser.add_argument('--model', type=str, default=None, help='模型权重路径(.pth文件)')
  423. parser.add_argument('--device', type=str, default=None, choices=['cuda', 'cpu'], help='设备 (cuda/cpu)')
  424. args = parser.parse_args()
  425. # 获取脚本所在目录
  426. script_dir = os.path.dirname(os.path.abspath(__file__))
  427. # 如果提供了命令行参数
  428. if args.input and args.output:
  429. input_path = args.input
  430. output_path = args.output
  431. # 判断是文件夹还是文件
  432. if os.path.isdir(input_path):
  433. # 批量处理文件夹
  434. process_folder_birefnet(input_path, output_path, model_path=args.model, device=args.device)
  435. else:
  436. # 处理单个文件
  437. extract_character_birefnet(
  438. input_path,
  439. output_path,
  440. model_path=args.model,
  441. device=args.device
  442. )
  443. else:
  444. # 使用默认路径:InputImage 文件夹 -> OutPutImage 文件夹
  445. input_folder = os.path.join(script_dir, 'InputImage')
  446. output_folder = os.path.join(script_dir, 'OutPutImage')
  447. if os.path.exists(input_folder):
  448. # 批量处理文件夹
  449. process_folder_birefnet(input_folder, output_folder, model_path=args.model, device=args.device)
  450. else:
  451. # 如果 InputImage 不存在,尝试处理单个文件(向后兼容)
  452. input_image_path = os.path.join(script_dir, 'image3.png')
  453. result_dir = os.path.join(script_dir, 'result')
  454. output_image_path = os.path.join(result_dir, 'image3_extracted_birefnet.png')
  455. extract_character_birefnet(
  456. input_image_path,
  457. output_image_path,
  458. model_path=args.model,
  459. device=args.device
  460. )