| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555 |
- """
- 使用 BiRefNet 进行图像背景移除
- 使用项目中已有的 BiRefNet 环境,效果最好
- 使用方法:
- python ImageMatting_birefnet.py # 批量处理 InputImage 文件夹
- python ImageMatting_birefnet.py input.png output.png # 处理单个文件
- """
- import os
- import sys
- import torch
- from PIL import Image
- from torchvision import transforms
- from glob import glob
# Image file extensions the batch helpers look for.
SUPPORTED_FORMATS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tiff', '.tif']

# The bundled BiRefNet checkout lives in ./BiRefNet next to this script. Putting it
# first on sys.path lets the `models`, `image_proc` and `utils` imports below resolve
# to that checkout's modules.
script_dir = os.path.dirname(os.path.abspath(__file__))
birefnet_dir = os.path.join(script_dir, 'BiRefNet')
sys.path.insert(0, birefnet_dir)

# Placeholders that remain None when the local BiRefNet modules cannot be imported.
# The loader functions then fall back to transformers.AutoModelForImageSegmentation.
BiRefNet = refine_foreground = check_state_dict = None
USE_LOCAL_BIREFNET = False

try:
    from models.birefnet import BiRefNet
    from image_proc import refine_foreground
    from utils import check_state_dict
except ImportError as local_import_error:
    print(f'警告: 无法导入本地 BiRefNet 模块: {local_import_error}')
    print('将使用 HuggingFace transformers 方式加载模型')
    # Without the local modules, transformers is the only remaining way to obtain a
    # model, so abort the whole script right away if it is missing as well.
    try:
        from transformers import AutoModelForImageSegmentation
    except ImportError:
        print('错误: 请安装 transformers 库: pip install transformers')
        sys.exit(1)
    else:
        print('已导入 transformers.AutoModelForImageSegmentation')
else:
    USE_LOCAL_BIREFNET = True
    print('已加载本地 BiRefNet 模块')
def extract_character_birefnet(input_path, output_path, model_path=None, device=None):
    """
    Remove the background of a single image with BiRefNet.

    The model is loaded on every call. These sources are tried in order, and
    loading stops at the first one that succeeds:

    1. The local HuggingFace-format folder ``<script dir>/BiRefNet/model/BiRefNet``.
       This is tried only when ``model_path`` is None and ``model.safetensors``
       exists there.
    2. The ``.pth`` state dict at ``model_path``, loaded into the local
       ``BiRefNet`` class.
    3. ``BiRefNet.from_pretrained('zhengpeng7/BiRefNet')``.
    4. ``transformers.AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet')``.

    For many images, use ``load_birefnet_model`` together with
    ``extract_character_birefnet_with_model``. That pair loads the model only once.

    Args:
        input_path: Path of the input image.
        output_path: Path of the output image. It is always written in PNG format,
            with the predicted mask as its alpha channel.
        model_path: Optional ``.pth`` weight file. A path that does not exist is
            reported, and the remaining sources are tried.
        device: 'cuda' or 'cpu'. It is auto-detected when None. 'cuda' falls back
            to 'cpu' when CUDA is unavailable.

    Returns:
        True on success, False on failure. Errors are printed rather than raised.
    """
    import contextlib

    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))

        if not os.path.exists(input_path):
            print(f'错误: {input_path} 文件不存在')
            return False

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Resolve the device.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'

        # Allow reduced-precision float32 matmul kernels (process-wide setting).
        torch.set_float32_matmul_precision('high')

        birefnet = None

        # Source 1: local HuggingFace-format checkout.
        local_model_dir = os.path.join(script_dir, 'BiRefNet', 'model', 'BiRefNet')
        local_model_path = os.path.join(local_model_dir, 'model.safetensors')
        if model_path is None and os.path.exists(local_model_path):
            try:
                print(f'尝试从本地加载模型: {local_model_path}')
                from transformers import AutoModelForImageSegmentation
                birefnet = AutoModelForImageSegmentation.from_pretrained(
                    local_model_dir,
                    trust_remote_code=True
                )
                print('成功从本地加载模型')
            except Exception as e:
                print(f'从本地加载失败: {e},尝试其他方法...')

        # Source 2: explicit .pth weights.
        if birefnet is None and model_path:
            if not os.path.exists(model_path):
                print(f'警告: 模型文件不存在: {model_path},尝试其他方法...')
            else:
                try:
                    print(f'从 {model_path} 加载模型权重...')
                    candidate = BiRefNet(bb_pretrained=False)
                    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
                    candidate.load_state_dict(check_state_dict(state_dict))
                    # Accept the network only after its weights loaded. Otherwise a
                    # randomly initialised model would be used and sources 3/4 skipped.
                    birefnet = candidate
                    print('成功加载 .pth 模型权重')
                except Exception as e:
                    print(f'加载 .pth 文件失败: {e}')

        # Sources 3 and 4: download from HuggingFace.
        if birefnet is None:
            try:
                print('尝试从 HuggingFace 加载模型...')
                birefnet = BiRefNet.from_pretrained('zhengpeng7/BiRefNet')
                print('成功从 HuggingFace 加载模型')
            except Exception as e:
                # BiRefNet is None when the local modules failed to import, which
                # also ends up here.
                print(f'从 HuggingFace 加载失败: {e}')
                try:
                    from transformers import AutoModelForImageSegmentation
                    birefnet = AutoModelForImageSegmentation.from_pretrained(
                        'ZhengPeng7/BiRefNet',
                        trust_remote_code=True
                    )
                    print('成功通过 transformers 从 HuggingFace 加载模型')
                except Exception as e2:
                    print(f'所有加载方法都失败了: {e2}')
                    return False

        birefnet.to(device)
        birefnet.eval()

        # Run the model in FP16 on CUDA when possible.
        use_half = device == 'cuda'
        if use_half:
            try:
                birefnet.half()
                print('已启用 FP16 精度')
            except Exception:
                print('FP16 不可用,使用 FP32')
                use_half = False

        print('正在移除背景...')

        # Resize to a fixed 1024x1024 input, then apply ImageNet mean/std normalisation.
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Decode inside a context manager so the file handle is released immediately.
        # convert() returns an independent in-memory image.
        with Image.open(input_path) as src:
            original_size = src.size
            image = src.convert('RGB')

        input_tensor = transform_image(image).unsqueeze(0)
        if use_half:
            input_tensor = input_tensor.half()
        input_tensor = input_tensor.to(device)

        if device == 'cuda':
            autocast_ctx = torch.amp.autocast(
                device_type=device,
                dtype=torch.float16 if use_half else None
            )
        else:
            autocast_ctx = contextlib.nullcontext()

        with autocast_ctx, torch.no_grad():
            # Use the last element of the model output, squashed into (0, 1) by sigmoid.
            preds = birefnet(input_tensor)[-1].sigmoid().to(torch.float32).cpu()

        pred = preds[0].squeeze()

        # Turn the prediction into a PIL mask and scale it back to the input size.
        pred_pil = transforms.ToPILImage()(pred)
        pred_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)

        try:
            image_masked = refine_foreground(image, pred_pil, device=device)
            image_masked.putalpha(pred_pil)
        except Exception as e:
            # Best-effort refinement. Fall back to the plain RGB image, e.g. when
            # refine_foreground is None because the local modules failed to import.
            print(f'前景优化失败,使用简单方法: {e}')
            image_masked = image.copy()
            image_masked.putalpha(pred_pil)

        image_masked.save(output_path, 'PNG')
        return True

    except Exception as error:
        print(f'处理图像时出错: {error}')
        import traceback
        traceback.print_exc()
        return False
def load_birefnet_model(model_path=None, device=None):
    """
    Load the BiRefNet model once, for reuse across a batch of images.

    These sources are tried in order, and loading stops at the first one that
    succeeds:

    1. The local HuggingFace-format folder ``<script dir>/BiRefNet/model/BiRefNet``.
       This is tried only when ``model_path`` is None and ``model.safetensors``
       exists there.
    2. The ``.pth`` state dict at ``model_path``, loaded into the local
       ``BiRefNet`` class.
    3. ``BiRefNet.from_pretrained('zhengpeng7/BiRefNet')``.
    4. ``transformers.AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet')``.

    Failures of sources 1 and 2 are printed as warnings, and loading continues.

    Args:
        model_path: Optional ``.pth`` weight file.
        device: 'cuda' or 'cpu'. It is auto-detected when None. 'cuda' falls back
            to 'cpu' when CUDA is unavailable.

    Returns:
        (model, device, use_half), where:
        - model is in eval mode on ``device``;
        - device is the device actually used;
        - use_half tells whether the model was converted to FP16.

    Raises:
        RuntimeError: If every loading source fails. It is a subclass of
            Exception, so existing ``except Exception`` callers still catch it.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'

    use_half = device == 'cuda'

    script_dir = os.path.dirname(os.path.abspath(__file__))
    birefnet = None

    # Source 1: local HuggingFace-format checkout.
    local_model_dir = os.path.join(script_dir, 'BiRefNet', 'model', 'BiRefNet')
    local_model_path = os.path.join(local_model_dir, 'model.safetensors')
    if model_path is None and os.path.exists(local_model_path):
        try:
            from transformers import AutoModelForImageSegmentation
            birefnet = AutoModelForImageSegmentation.from_pretrained(
                local_model_dir,
                trust_remote_code=True
            )
        except Exception as e:
            print(f'从本地加载失败: {e},尝试其他方法...')

    # Source 2: explicit .pth weights.
    if birefnet is None and model_path:
        if not os.path.exists(model_path):
            print(f'警告: 模型文件不存在: {model_path},尝试其他方法...')
        else:
            try:
                candidate = BiRefNet(bb_pretrained=False)
                state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
                candidate.load_state_dict(check_state_dict(state_dict))
                # Accept the network only after its weights loaded. Otherwise a
                # randomly initialised model would be returned and sources 3/4 skipped.
                birefnet = candidate
            except Exception as e:
                print(f'加载 .pth 文件失败: {e}')

    # Sources 3 and 4: download from HuggingFace.
    if birefnet is None:
        try:
            birefnet = BiRefNet.from_pretrained('zhengpeng7/BiRefNet')
        except Exception:
            # Also reached when BiRefNet is None (the local modules failed to import).
            try:
                from transformers import AutoModelForImageSegmentation
                birefnet = AutoModelForImageSegmentation.from_pretrained(
                    'ZhengPeng7/BiRefNet',
                    trust_remote_code=True
                )
            except Exception as e2:
                raise RuntimeError(f'所有加载方法都失败了: {e2}') from e2

    birefnet.to(device)
    birefnet.eval()

    if use_half:
        try:
            birefnet.half()
        except Exception:
            use_half = False

    return birefnet, device, use_half
def extract_character_birefnet_with_model(input_path, output_path, birefnet, device, use_half):
    """
    Remove the background of one image using an already-loaded BiRefNet model.

    This is meant for batch processing. The last three arguments are the
    ``(model, device, use_half)`` tuple returned by ``load_birefnet_model``.

    Args:
        input_path: Path of the input image.
        output_path: Destination path. It is written in PNG format, with the
            predicted mask as its alpha channel. Missing parent folders are created.
        birefnet: Loaded model, already on ``device``.
        device: Device string the model lives on ('cuda' or 'cpu').
        use_half: True if the model was converted to FP16. The input tensor is
            converted to match.

    Returns:
        True on success, False on failure. The reason for a failure is printed, so
        batch callers can show why an image was skipped.
    """
    import contextlib

    try:
        if not os.path.exists(input_path):
            print(f' 错误: {input_path} 文件不存在')
            return False

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Resize to a fixed 1024x1024 input, then apply ImageNet mean/std normalisation.
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Decode inside a context manager so the file handle is released immediately.
        # convert() returns an independent in-memory image.
        with Image.open(input_path) as src:
            original_size = src.size
            image = src.convert('RGB')

        input_tensor = transform_image(image).unsqueeze(0)
        if use_half:
            input_tensor = input_tensor.half()
        input_tensor = input_tensor.to(device)

        if device == 'cuda':
            autocast_ctx = torch.amp.autocast(
                device_type=device,
                dtype=torch.float16 if use_half else None
            )
        else:
            autocast_ctx = contextlib.nullcontext()

        with autocast_ctx, torch.no_grad():
            # Use the last element of the model output, squashed into (0, 1) by sigmoid.
            preds = birefnet(input_tensor)[-1].sigmoid().to(torch.float32).cpu()

        pred = preds[0].squeeze()

        # Turn the prediction into a PIL mask and scale it back to the input size.
        pred_pil = transforms.ToPILImage()(pred)
        pred_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)

        try:
            image_masked = refine_foreground(image, pred_pil, device=device)
            image_masked.putalpha(pred_pil)
        except Exception:
            # Best-effort refinement. It fails for every image when refine_foreground
            # is None (the local modules failed to import), so fall back quietly
            # instead of printing the same message per image.
            image_masked = image.copy()
            image_masked.putalpha(pred_pil)

        image_masked.save(output_path, 'PNG')
        return True

    except Exception as error:
        print(f' 处理图像时出错: {error}')
        return False
def process_folder_birefnet(input_folder, output_folder, model_path=None, device=None):
    """
    Remove the background of every supported image under a folder.

    Images are collected recursively. Regular files whose extension matches
    ``SUPPORTED_FORMATS`` (case-insensitively) are processed in sorted order.
    Results are written flat into ``output_folder`` as ``<stem>.png``. When
    several inputs share a stem (e.g. ``a.jpg`` and ``sub/a.png``), later ones
    become ``<stem>_1.png``, ``<stem>_2.png``, ... instead of overwriting
    earlier results. The model is loaded only once for the whole folder.

    Args:
        input_folder: Folder to search for input images.
        output_folder: Folder receiving the PNG results. It is created if missing.
        model_path: Optional ``.pth`` weight file passed to ``load_birefnet_model``.
        device: 'cuda', 'cpu' or None (auto-detect).
    """
    from pathlib import Path

    if not os.path.exists(input_folder):
        print(f'错误: 输入文件夹不存在: {input_folder}')
        return

    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)
        print(f'已创建输出文件夹: {output_folder}')

    # A single walk with case-insensitive suffix matching. Separate lower- and
    # upper-case globs miss mixed-case extensions such as ".Jpg" and can report the
    # same file twice on case-insensitive filesystems.
    supported = {ext.lower() for ext in SUPPORTED_FORMATS}
    seen_paths = set()
    image_files = []
    for candidate in sorted(Path(input_folder).rglob('*')):
        if not candidate.is_file() or candidate.suffix.lower() not in supported:
            continue
        # De-duplicate by resolved path (e.g. the same file reached twice).
        resolved = str(candidate.resolve())
        if resolved in seen_paths:
            continue
        seen_paths.add(resolved)
        image_files.append(candidate)

    if not image_files:
        print(f'在 {input_folder} 中未找到支持的图片文件')
        return

    print(f'找到 {len(image_files)} 张图片')
    print('正在加载模型...')

    try:
        birefnet, device_actual, use_half = load_birefnet_model(model_path, device)
        print(f'模型加载成功,使用设备: {device_actual}')
        if use_half:
            print('已启用 FP16 精度')
    except Exception as e:
        print(f'模型加载失败: {e}')
        return

    print('开始处理...')
    print('=' * 50)

    total = len(image_files)
    success_count = 0
    # Output names already assigned in this run. They are lower-cased because
    # Windows filesystems treat "A.png" and "a.png" as the same file.
    used_names = set()
    for idx, input_path in enumerate(image_files, 1):
        output_name = f'{input_path.stem}.png'
        dup_index = 1
        while output_name.lower() in used_names:
            output_name = f'{input_path.stem}_{dup_index}.png'
            dup_index += 1
        used_names.add(output_name.lower())
        output_path = os.path.join(output_folder, output_name)

        print(f'[{idx}/{total}] 正在处理: {input_path.name}')

        if extract_character_birefnet_with_model(input_path, output_path, birefnet, device_actual, use_half):
            print(f' ✓ 成功保存到: {output_name}')
            success_count += 1
        else:
            print(' ✗ 处理失败')
        print()

    print('=' * 50)
    print(f'处理完成!成功: {success_count}/{total}')
def process_folder_birefnet_by_id(unique_id, rec_base_dir, send_base_dir):
    """
    Process the images of one task and package the results as a zip archive.

    Images are read recursively from ``<rec_base_dir>/<unique_id>``. One
    transparent PNG per image is written to ``<send_base_dir>/<unique_id>``.
    Finally, the images found in that output folder are zipped into
    ``<send_base_dir>/<unique_id>/<unique_id>.zip``. The model is loaded once
    with default settings: no ``.pth`` path and an auto-detected device.

    Inputs whose stems collide (e.g. ``a.jpg`` and ``sub/a.png``) get distinct
    output names (``a.png``, ``a_1.png``, ...) instead of overwriting each other.

    NOTE(review): ``unique_id`` is joined into filesystem paths unchanged. If it
    can come from an external request, validate it before calling this function
    (no path separators or '..').

    Args:
        unique_id: Task identifier, used as the sub-folder name and the zip file name.
        rec_base_dir: Base directory containing the per-task input folders.
        send_base_dir: Base directory receiving the per-task output folders.
    """
    import zipfile
    from pathlib import Path

    input_folder = os.path.join(rec_base_dir, unique_id)
    output_folder = os.path.join(send_base_dir, unique_id)

    if not os.path.exists(input_folder):
        print(f'错误: 输入文件夹不存在: {input_folder}')
        return

    if not os.path.exists(output_folder):
        os.makedirs(output_folder, exist_ok=True)
        print(f'已创建输出文件夹: {output_folder}')

    # A single walk with case-insensitive suffix matching. Separate lower- and
    # upper-case globs miss mixed-case extensions and can report the same file
    # twice on case-insensitive filesystems.
    supported = {ext.lower() for ext in SUPPORTED_FORMATS}
    seen_paths = set()
    image_files = []
    for candidate in sorted(Path(input_folder).rglob('*')):
        if not candidate.is_file() or candidate.suffix.lower() not in supported:
            continue
        resolved = str(candidate.resolve())
        if resolved in seen_paths:
            continue
        seen_paths.add(resolved)
        image_files.append(candidate)

    if not image_files:
        print(f'在 {input_folder} 中未找到支持的图片文件')
        return

    print(f'找到 {len(image_files)} 张图片')
    print('正在加载模型...')

    try:
        birefnet, device_actual, use_half = load_birefnet_model(None, None)
        print(f'模型加载成功,使用设备: {device_actual}')
        if use_half:
            print('已启用 FP16 精度')
    except Exception as e:
        print(f'模型加载失败: {e}')
        return

    print('开始处理...')
    print('=' * 50)

    total = len(image_files)
    success_count = 0
    # Output names already assigned in this run, lower-cased for Windows filesystems.
    used_names = set()
    for idx, input_path in enumerate(image_files, 1):
        output_name = f'{input_path.stem}.png'
        dup_index = 1
        while output_name.lower() in used_names:
            output_name = f'{input_path.stem}_{dup_index}.png'
            dup_index += 1
        used_names.add(output_name.lower())
        output_path = os.path.join(output_folder, output_name)

        print(f'[{idx}/{total}] 正在处理: {input_path.name}')

        if extract_character_birefnet_with_model(input_path, output_path, birefnet, device_actual, use_half):
            print(f' ✓ 成功保存到: {output_name}')
            success_count += 1
        else:
            print(' ✗ 处理失败')
        print()

    print('=' * 50)
    print(f'处理完成!成功: {success_count}/{total}')

    # Package every image now in the output folder.
    output_zip_path = os.path.join(output_folder, f'{unique_id}.zip')
    print(f'正在打包到: {output_zip_path}')

    # iterdir() lists each file exactly once. Pairing '*.png' with '*.PNG' globs
    # matches the same files twice on Windows and produced duplicate zip entries.
    # The archive itself is excluded because '.zip' is not a supported suffix.
    processed_images = sorted(
        p for p in Path(output_folder).iterdir()
        if p.is_file() and p.suffix.lower() in supported
    )

    if processed_images:
        with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for img_path in processed_images:
                zipf.write(str(img_path), img_path.name)
        print(f'打包完成: {output_zip_path}')
    else:
        print('警告: 没有找到处理后的图片文件')
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='使用 BiRefNet 移除图像背景')
    parser.add_argument('input', nargs='?', help='输入图像路径或输入文件夹路径')
    parser.add_argument('output', nargs='?', help='输出图像路径或输出文件夹路径')
    parser.add_argument('--model', type=str, default=None, help='模型权重路径(.pth文件)')
    parser.add_argument('--device', type=str, default=None, choices=['cuda', 'cpu'], help='设备 (cuda/cpu)')

    args = parser.parse_args()

    # Directory containing this script; the default input/output folders live here.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Giving only one positional argument would otherwise fall through to the
    # default folders and silently ignore it, so reject that combination.
    if bool(args.input) != bool(args.output):
        parser.error('input 和 output 必须同时提供')

    if args.input and args.output:
        if os.path.isdir(args.input):
            # Batch mode: a whole folder.
            process_folder_birefnet(args.input, args.output, model_path=args.model, device=args.device)
        else:
            # Single-file mode; report failure through the exit status.
            ok = extract_character_birefnet(
                args.input,
                args.output,
                model_path=args.model,
                device=args.device
            )
            sys.exit(0 if ok else 1)
    else:
        # Default paths: InputImage folder -> OutPutImage folder next to the script.
        input_folder = os.path.join(script_dir, 'InputImage')
        output_folder = os.path.join(script_dir, 'OutPutImage')

        if os.path.exists(input_folder):
            process_folder_birefnet(input_folder, output_folder, model_path=args.model, device=args.device)
        else:
            # Backward-compatible fallback: a single image3.png next to the script.
            input_image_path = os.path.join(script_dir, 'image3.png')
            result_dir = os.path.join(script_dir, 'result')
            output_image_path = os.path.join(result_dir, 'image3_extracted_birefnet.png')
            ok = extract_character_birefnet(
                input_image_path,
                output_image_path,
                model_path=args.model,
                device=args.device
            )
            sys.exit(0 if ok else 1)
|