# handler.py — BiRefNet Hugging Face inference endpoint handler (4.7 KB, 140 lines)
1 # These HF deployment codes refer to https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
2 from typing import Dict, List, Any, Tuple
3 import os
4 import requests
5 from io import BytesIO
6 import cv2
7 import numpy as np
8 from PIL import Image
9 import torch
10 from torchvision import transforms
11 from transformers import AutoModelForImageSegmentation
12
# Allow TF32 matmuls on Ampere+ GPUs: trades a little float32 precision for speed.
torch.set_float32_matmul_precision("high")

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
16
17 ### image_proc.py
def refine_foreground(image, mask, r=90):
    """Refine the foreground colors of `image` under `mask` via blur fusion.

    Both arguments are PIL images; the mask is resized to the image size if
    they differ. Returns a new RGB PIL image with estimated foreground colors.
    """
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image_arr = np.array(image) / 255.0
    mask_arr = np.array(mask) / 255.0
    foreground = FB_blur_fusion_foreground_estimator_2(image_arr, mask_arr, r=r)
    return Image.fromarray((foreground * 255.0).astype(np.uint8))
26
27
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation (coarse pass, then fine pass).

    Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    `image` is an HxWx3 float array in [0, 1]; `alpha` is an HxW float matte.
    Returns the refined foreground, same shape as `image`.
    """
    alpha = alpha[:, :, None]  # promote matte to HxWx1 for broadcasting
    # Coarse pass with a large radius, then refine with a small fixed radius.
    F_coarse, B_blur = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    F_fine, _ = FB_blur_fusion_foreground_estimator(image, F_coarse, B_blur, alpha, r=6)
    return F_fine
33
34
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """One blur-fusion pass estimating foreground F and blurred background B.

    `F`/`B` are the current foreground/background estimates and `alpha` the
    HxWx1 matte; all arrays are float in [0, 1]. Returns (F, blurred_B) where
    F is clipped to [0, 1].
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    eps = 1e-5  # guard against division by ~zero in (de)blurred mattes
    kernel = (r, r)
    alpha_blur = cv2.blur(alpha, kernel)[:, :, None]

    # Local mean of foreground weighted by alpha, renormalized.
    F_blur = cv2.blur(F * alpha, kernel) / (alpha_blur + eps)
    # Local mean of background weighted by (1 - alpha), renormalized.
    B_blur = cv2.blur(B * (1 - alpha), kernel) / ((1 - alpha_blur) + eps)

    # Correct the foreground estimate by the compositing residual.
    residual = image - alpha * F_blur - (1 - alpha) * B_blur
    F_new = np.clip(F_blur + alpha * residual, 0, 1)
    return F_new, B_blur
49
50
class ImagePreprocessor():
    """Resize, tensorize and ImageNet-normalize a PIL image for BiRefNet."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # Normalization constants are the standard ImageNet mean/std used by
        # the BiRefNet backbone.
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Return the transformed image as a CxHxW float tensor."""
        return self.transform_image(image)
62
# Map from task/usage name to the corresponding BiRefNet weights repo suffix.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'Matting': 'BiRefNet-matting',
    'Matting-HR': 'BiRefNet_HR-Matting',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy',
}

# Choose the version of BiRefNet here.
usage = 'General'

# Input resolution per variant; anything not listed uses the 1024x1024 default.
_RESOLUTION_BY_USAGE = {
    'General-Lite-2K': (2560, 1440),
    'General-reso_512': (512, 512),
    'General-HR': (2048, 2048),
    'Matting-HR': (2048, 2048),
}
resolution = _RESOLUTION_BY_USAGE.get(usage, (1024, 1024))

# Run inference in fp16 to halve memory use and speed up GPU inference.
half_precision = True
93
class EndpointHandler():
    """Hugging Face Inference Endpoint handler wrapping BiRefNet.

    Loads the BiRefNet variant selected by the module-level `usage` and, per
    request, returns the input image with the predicted foreground mask
    applied as an alpha channel.
    """

    def __init__(self, path=''):
        # Weights live under the `zhengpeng7` hub namespace; the repo ships
        # custom modeling code, hence trust_remote_code=True.
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
        )
        self.birefnet.to(device)
        self.birefnet.eval()
        if half_precision:
            self.birefnet.half()

    def __call__(self, data: Dict[str, Any]):
        """
        data args:
            inputs (:obj: `str` | array-like): a local file path, an image
                URL, or a numpy-compatible pixel array.
        Return:
            A :obj:`PIL.Image.Image`: RGBA foreground cutout; will be
            serialized and returned.
        """
        print('data["inputs"] = ', data["inputs"])
        image_src = data["inputs"]
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                # Timeout so a dead URL cannot hang the endpoint forever;
                # raise_for_status surfaces HTTP errors instead of letting
                # Image.open fail later with a confusing decode error.
                response = requests.get(image_src, timeout=30)
                response.raise_for_status()
                image_ori = Image.open(BytesIO(response.content))
        else:
            image_ori = Image.fromarray(image_src)

        image = image_ori.convert('RGB')
        # Preprocess the image (resize to model resolution + normalize).
        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image).unsqueeze(0)  # add batch dim

        # Prediction
        with torch.no_grad():
            input_images = image_proc.to(device)
            if half_precision:
                input_images = input_images.half()
            # The model returns multi-scale outputs; the last is the final mask.
            preds = self.birefnet(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Refine foreground colors under the predicted matte, then attach the
        # mask as alpha at the original image size.
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
        return image_masked
140