handler.py
| 1 | # These HF deployment codes refer to https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py. |
| 2 | from typing import Dict, List, Any, Tuple |
| 3 | import os |
| 4 | import requests |
| 5 | from io import BytesIO |
| 6 | import cv2 |
| 7 | import numpy as np |
| 8 | from PIL import Image |
| 9 | import torch |
| 10 | from torchvision import transforms |
| 11 | from transformers import AutoModelForImageSegmentation |
| 12 | |
# Allow TF32-style fast matmul; "high" trades a little precision for speed.
torch.set_float32_matmul_precision("high")

# Run on GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
| 16 | |
| 17 | ### image_proc.py |
def refine_foreground(image, mask, r=90):
    """Estimate a clean foreground for *image* using the matte *mask*.

    image: RGB PIL.Image. mask: single-channel PIL.Image; resized to
    match *image* if needed. r: blur radius forwarded to the estimator.
    Returns the estimated foreground as an 8-bit PIL.Image.
    """
    if mask.size != image.size:
        mask = mask.resize(image.size)
    img_arr = np.array(image) / 255.0
    alpha_arr = np.array(mask) / 255.0
    foreground = FB_blur_fusion_foreground_estimator_2(img_arr, alpha_arr, r=r)
    return Image.fromarray((foreground * 255.0).astype(np.uint8))
| 26 | |
| 27 | |
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Two-pass blur-fusion foreground estimation.

    Credit: https://github.com/Photoroom/fast-foreground-estimation
    A coarse pass with radius *r* is refined by a second pass with a
    small fixed radius (6). Returns only the refined foreground.
    """
    alpha = alpha[:, :, None]
    coarse_F, coarse_blur_B = FB_blur_fusion_foreground_estimator(
        image, image, image, alpha, r)
    refined_F, _ = FB_blur_fusion_foreground_estimator(
        image, coarse_F, coarse_blur_B, alpha, r=6)
    return refined_F
| 33 | |
| 34 | |
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """One blur-fusion step: estimate foreground F and a blurred background.

    image, F, B: HxWx3 arrays in [0, 1] (a PIL image is converted first).
    alpha: HxWx1 matte in [0, 1]. r: box-blur kernel size.
    Returns (F, blurred_B) with F clipped to [0, 1].
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    kernel = (r, r)
    # Local mean of alpha; re-add the channel axis dropped by cv2.blur.
    mean_alpha = cv2.blur(alpha, kernel)[:, :, None]

    # Alpha-weighted local means of foreground and background
    # (epsilon guards against division by zero in flat regions).
    mean_F = cv2.blur(F * alpha, kernel) / (mean_alpha + 1e-5)
    mean_B = cv2.blur(B * (1 - alpha), kernel) / ((1 - mean_alpha) + 1e-5)

    # Correct the blurred foreground with the compositing residual.
    residual = image - alpha * mean_F - (1 - alpha) * mean_B
    F_est = np.clip(mean_F + alpha * residual, 0, 1)
    return F_est, mean_B
| 49 | |
| 50 | |
class ImagePreprocessor():
    """Resize, tensorize, and ImageNet-normalize a PIL image for BiRefNet."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        # ImageNet channel statistics used by the pretrained backbone.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform pipeline; returns a CxHxW float tensor."""
        return self.transform_image(image)
| 62 | |
# Map each usage name to its Hugging Face weights repository suffix.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-HR': 'BiRefNet_HR',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'Matting': 'BiRefNet-matting',
    'Matting-HR': 'BiRefNet_HR-Matting',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy'
}

# Choose the version of BiRefNet here.
usage = 'General'

# Inference resolution per model variant; everything else uses 1024x1024.
_SPECIAL_RESOLUTIONS = {
    'General-Lite-2K': (2560, 1440),
    'General-reso_512': (512, 512),
    'General-HR': (2048, 2048),
    'Matting-HR': (2048, 2048),
}
resolution = _SPECIAL_RESOLUTIONS.get(usage, (1024, 1024))

# Run the model in fp16 to halve memory use and speed up inference.
half_precision = True
| 93 | |
class EndpointHandler():
    """Hugging Face Inference Endpoint handler for BiRefNet segmentation."""

    def __init__(self, path=''):
        # Load the selected BiRefNet variant from the Hub; trust_remote_code
        # is required because the model class ships with the checkpoint.
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
        )
        self.birefnet.to(device)
        self.birefnet.eval()
        if half_precision:
            self.birefnet.half()

    def __call__(self, data: Dict[str, Any]):
        """
        data args:
            inputs (:obj:`str` | array): a local file path, an image URL,
                or an array of image pixels.
        Return:
            A PIL.Image whose alpha channel holds the predicted matte;
            it will be serialized and returned.
        """
        print('data["inputs"] = ', data["inputs"])
        image_src = data["inputs"]
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                # Bound the fetch with a timeout so a stalled remote server
                # cannot hang the endpoint forever, and fail fast on HTTP
                # errors instead of handing an error page to PIL.
                response = requests.get(image_src, timeout=30)
                response.raise_for_status()
                image_ori = Image.open(BytesIO(response.content))
        else:
            image_ori = Image.fromarray(image_src)

        image = image_ori.convert('RGB')
        # Preprocess: resize to model resolution, normalize, add batch dim.
        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image).unsqueeze(0)

        # Prediction: take the last decoder output, squash to [0, 1],
        # and move it back to the CPU for PIL conversion.
        with torch.no_grad():
            inputs = image_proc.to(device)
            if half_precision:
                inputs = inputs.half()
            preds = self.birefnet(inputs)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Compose the result: refined foreground RGB + predicted alpha matte.
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
        return image_masked
| 140 | |