handler.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
# This Hugging Face deployment code is based on https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
  2. from typing import Dict, List, Any, Tuple
  3. import os
  4. import requests
  5. from io import BytesIO
  6. import cv2
  7. import numpy as np
  8. from PIL import Image
  9. import torch
  10. from torchvision import transforms
  11. from transformers import AutoModelForImageSegmentation
  12. torch.set_float32_matmul_precision(["high", "highest"][0])
  13. device = "cuda" if torch.cuda.is_available() else "cpu"
  14. ### image_proc.py
  15. def refine_foreground(image, mask, r=90):
  16. if mask.size != image.size:
  17. mask = mask.resize(image.size)
  18. image = np.array(image) / 255.0
  19. mask = np.array(mask) / 255.0
  20. estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
  21. image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
  22. return image_masked
  23. def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
  24. # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
  25. alpha = alpha[:, :, None]
  26. F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
  27. return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
  28. def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
  29. if isinstance(image, Image.Image):
  30. image = np.array(image) / 255.0
  31. blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
  32. blurred_FA = cv2.blur(F * alpha, (r, r))
  33. blurred_F = blurred_FA / (blurred_alpha + 1e-5)
  34. blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
  35. blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
  36. F = blurred_F + alpha * \
  37. (image - alpha * blurred_F - (1 - alpha) * blurred_B)
  38. F = np.clip(F, 0, 1)
  39. return F, blurred_B
  40. class ImagePreprocessor():
  41. def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
  42. self.transform_image = transforms.Compose([
  43. transforms.Resize(resolution),
  44. transforms.ToTensor(),
  45. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
  46. ])
  47. def proc(self, image: Image.Image) -> torch.Tensor:
  48. image = self.transform_image(image)
  49. return image
  50. usage_to_weights_file = {
  51. 'General': 'BiRefNet',
  52. 'General-HR': 'BiRefNet_HR',
  53. 'General-Lite': 'BiRefNet_lite',
  54. 'General-Lite-2K': 'BiRefNet_lite-2K',
  55. 'General-reso_512': 'BiRefNet-reso_512',
  56. 'Matting': 'BiRefNet-matting',
  57. 'Matting-HR': 'BiRefNet_HR-Matting',
  58. 'Portrait': 'BiRefNet-portrait',
  59. 'DIS': 'BiRefNet-DIS5K',
  60. 'HRSOD': 'BiRefNet-HRSOD',
  61. 'COD': 'BiRefNet-COD',
  62. 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
  63. 'General-legacy': 'BiRefNet-legacy'
  64. }
  65. # Choose the version of BiRefNet here.
  66. usage = 'General'
  67. # Set resolution
  68. if usage in ['General-Lite-2K']:
  69. resolution = (2560, 1440)
  70. elif usage in ['General-reso_512']:
  71. resolution = (512, 512)
  72. elif usage in ['General-HR', 'Matting-HR']:
  73. resolution = (2048, 2048)
  74. else:
  75. resolution = (1024, 1024)
  76. half_precision = True
  77. class EndpointHandler():
  78. def __init__(self, path=''):
  79. self.birefnet = AutoModelForImageSegmentation.from_pretrained(
  80. '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
  81. )
  82. self.birefnet.to(device)
  83. self.birefnet.eval()
  84. if half_precision:
  85. self.birefnet.half()
  86. def __call__(self, data: Dict[str, Any]):
  87. """
  88. data args:
  89. inputs (:obj: `str`)
  90. date (:obj: `str`)
  91. Return:
  92. A :obj:`list` | `dict`: will be serialized and returned
  93. """
  94. print('data["inputs"] = ', data["inputs"])
  95. image_src = data["inputs"]
  96. if isinstance(image_src, str):
  97. if os.path.isfile(image_src):
  98. image_ori = Image.open(image_src)
  99. else:
  100. response = requests.get(image_src)
  101. image_data = BytesIO(response.content)
  102. image_ori = Image.open(image_data)
  103. else:
  104. image_ori = Image.fromarray(image_src)
  105. image = image_ori.convert('RGB')
  106. # Preprocess the image
  107. image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
  108. image_proc = image_preprocessor.proc(image)
  109. image_proc = image_proc.unsqueeze(0)
  110. # Prediction
  111. with torch.no_grad():
  112. preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
  113. pred = preds[0].squeeze()
  114. # Show Results
  115. pred_pil = transforms.ToPILImage()(pred)
  116. image_masked = refine_foreground(image, pred_pil)
  117. image_masked.putalpha(pred_pil.resize(image.size))
  118. return image_masked