---
library_name: birefnet
tags:
- background-removal
- mask-generation
- Dichotomous Image Segmentation
- Camouflaged Object Detection
- Salient Object Detection
- pytorch_model_hub_mixin
- model_hub_mixin
- transformers
repo_url: https://github.com/ZhengPeng7/BiRefNet
pipeline_tag: image-segmentation
license: mit
---
<h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>

<div align='center'>
    <a href='https://scholar.google.com/citations?user=TZRzWOsAAAAJ' target='_blank'><strong>Peng Zheng</strong></a><sup> 1,4,5,6</sup>,&thinsp;
    <a href='https://scholar.google.com/citations?user=0uPb8MMAAAAJ' target='_blank'><strong>Dehong Gao</strong></a><sup> 2</sup>,&thinsp;
    <a href='https://scholar.google.com/citations?user=kakwJ5QAAAAJ' target='_blank'><strong>Deng-Ping Fan</strong></a><sup> 1*</sup>,&thinsp;
    <a href='https://scholar.google.com/citations?user=9cMQrVsAAAAJ' target='_blank'><strong>Li Liu</strong></a><sup> 3</sup>,&thinsp;
    <a href='https://scholar.google.com/citations?user=qQP6WXIAAAAJ' target='_blank'><strong>Jorma Laaksonen</strong></a><sup> 4</sup>,&thinsp;
    <a href='https://scholar.google.com/citations?user=pw_0Z_UAAAAJ' target='_blank'><strong>Wanli Ouyang</strong></a><sup> 5</sup>,&thinsp;
    <a href='https://scholar.google.com/citations?user=stFCYOAAAAAJ' target='_blank'><strong>Nicu Sebe</strong></a><sup> 6</sup>
</div>

<div align='center'>
    <sup>1 </sup>Nankai University&ensp;  <sup>2 </sup>Northwestern Polytechnical University&ensp;  <sup>3 </sup>National University of Defense Technology&ensp; <sup>4 </sup>Aalto University&ensp;  <sup>5 </sup>Shanghai AI Laboratory&ensp;  <sup>6 </sup>University of Trento&ensp; 
</div>

<div align="center" style="display: flex; justify-content: center; flex-wrap: wrap;">
  <a href='https://www.sciopen.com/article/pdf/10.26599/AIR.2024.9150038.pdf'><img src='https://img.shields.io/badge/Journal-Paper-red'></a>&ensp; 
  <a href='https://arxiv.org/pdf/2401.03407'><img src='https://img.shields.io/badge/arXiv-BiRefNet-red'></a>&ensp; 
  <a href='https://drive.google.com/file/d/1aBnJ_R9lbnC2dm8dqD0-pzP2Cu-U1Xpt/view?usp=drive_link'><img src='https://img.shields.io/badge/中文版-BiRefNet-red'></a>&ensp; 
  <a href='https://www.birefnet.top'><img src='https://img.shields.io/badge/Page-BiRefNet-red'></a>&ensp; 
  <a href='https://drive.google.com/drive/folders/1s2Xe0cjq-2ctnJBR24563yMSCOu4CcxM'><img src='https://img.shields.io/badge/Drive-Stuff-green'></a>&ensp; 
  <a href='LICENSE'><img src='https://img.shields.io/badge/License-MIT-yellow'></a>&ensp; 
  <a href='https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HF%20Spaces-BiRefNet-blue'></a>&ensp; 
  <a href='https://huggingface.co/ZhengPeng7/BiRefNet'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HF%20Models-BiRefNet-blue'></a>&ensp; 
  <a href='https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba?usp=drive_link'><img src='https://img.shields.io/badge/Single_Image_Inference-F9AB00?style=for-the-badge&logo=googlecolab&color=525252'></a>&ensp; 
  <a href='https://colab.research.google.com/drive/1MaEiBfJ4xIaZZn0DqKrhydHB8X97hNXl#scrollTo=DJ4meUYjia6S'><img src='https://img.shields.io/badge/Inference_&_Evaluation-F9AB00?style=for-the-badge&logo=googlecolab&color=525252'></a>&ensp; 
</div>

| *DIS-Sample_1*        |        *DIS-Sample_2*        |
| :------------------------------: | :-------------------------------: |
| <img src="https://drive.google.com/thumbnail?id=1ItXaA26iYnE8XQ_GgNLy71MOWePoS2-g&sz=w400" /> |  <img src="https://drive.google.com/thumbnail?id=1Z-esCujQF_uEa_YJjkibc3NUrW4aR_d4&sz=w400" /> |

This repo is the official implementation of "[**Bilateral Reference for High-Resolution Dichotomous Image Segmentation**](https://arxiv.org/pdf/2401.03407.pdf)" (___CAAI AIR 2024___).

Visit our GitHub repo, [https://github.com/ZhengPeng7/BiRefNet](https://github.com/ZhengPeng7/BiRefNet), for more details: **code**, **docs**, and the **model zoo**!

## How to use

### 0. Install Packages:
```shell
pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
```

### 1. Load BiRefNet:

#### Use code + weights from Hugging Face
> Use both the code and the weights hosted on Hugging Face. Pro: no need to download the BiRefNet code manually. Con: the code on Hugging Face might not be the latest version (I'll try to keep it up to date).

```python
# Load BiRefNet with weights
from transformers import AutoModelForImageSegmentation

# trust_remote_code=True lets transformers run the model code stored in this Hugging Face repo
birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
```

#### Use code from GitHub + weights from Hugging Face
> Use only the weights from Hugging Face, together with the code from GitHub. Pro: the code is always the latest version. Con: you need to clone the BiRefNet repo from my GitHub.

```shell
# Download the code
git clone https://github.com/ZhengPeng7/BiRefNet.git
cd BiRefNet
```

```python
# Use the code locally (run from the root of the cloned repo)
from models.birefnet import BiRefNet

# Load weights from Hugging Face Models
birefnet = BiRefNet.from_pretrained('ZhengPeng7/BiRefNet')
```

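#### Use code from GitHub + weights stored locally
> Use both the code and the weights locally.

```python
# Use the code and weights locally (run from the root of the cloned repo)
import torch
from models.birefnet import BiRefNet
from utils import check_state_dict

PATH_TO_WEIGHT = 'path/to/your/BiRefNet-weights.pth'  # placeholder: set this to your local checkpoint file

birefnet = BiRefNet(bb_pretrained=False)  # skip loading separate backbone weights; the full checkpoint is loaded below
state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)
```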
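
`check_state_dict` comes from `utils.py` in the GitHub repo. As far as we can tell, it cleans up checkpoint keys, e.g. removing the `_orig_mod.` prefix that `torch.compile` adds and the `module.` prefix added by `nn.DataParallel`, so the keys match a freshly built `BiRefNet`. The sketch below is a hypothetical stand-in written only to illustrate that idea (`strip_key_prefixes` is not part of the repo, and the key names are toy values); it uses a plain dict so it runs without PyTorch.

```python
from typing import Dict, Iterable, TypeVar

V = TypeVar("V")


# Hypothetical helper for illustration; assumed to mirror what check_state_dict does.
def strip_key_prefixes(
    state_dict: Dict[str, V],
    unwanted_prefixes: Iterable[str] = ("module.", "_orig_mod."),
) -> Dict[str, V]:
    """Return a new dict whose keys have wrapper prefixes removed.

    Each prefix is checked once, in the given order,
    so 'module._orig_mod.x' becomes 'x'.
    """
    cleaned: Dict[str, V] = {}
    for key, value in state_dict.items():
        for prefix in unwanted_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Toy checkpoint keys (hypothetical names)
ckpt = {"_orig_mod.decoder.conv.weight": 1, "module.bb.layer0.bias": 2, "squeeze_module.w": 3}
print(strip_key_prefixes(ckpt))
# {'decoder.conv.weight': 1, 'bb.layer0.bias': 2, 'squeeze_module.w': 3}
```

Returning a new dict instead of mutating the input in place keeps the original checkpoint untouched, which makes it easy to compare keys before and after cleaning.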

#### Use the loaded BiRefNet for inference
```python
# Imports
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

birefnet = ...  # BiRefNet loaded with any of the options above.
torch.set_float32_matmul_precision('high')  # or 'highest' for full FP32 matmul precision
birefnet.to('cuda')
birefnet.eval()
birefnet.half()

def extract_object(birefnet, imagepath):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image = Image.open(imagepath).convert('RGB')  # make sure the input has 3 channels
    input_images = transform_image(image).unsqueeze(0).to('cuda').half()

    # Prediction: the last output is the final mask logits
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().float().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)  # back to the original resolution
    image.putalpha(mask)  # use the mask as the alpha channel
    return image, mask

# Visualization
plt.axis("off")
plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
plt.show()
```
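
The transforms above boil down to simple per-value arithmetic: `ToTensor` scales 8-bit pixels to [0, 1] and `Normalize` computes `(x - mean) / std` per channel with the ImageNet statistics; on the output side, `sigmoid` maps the final logits to foreground probabilities, and `ToPILImage` turns that float map into an 8-bit grayscale mask (our understanding is that torchvision scales by 255 and casts to 8-bit integers, which truncates). The pure-Python sketch below mimics that arithmetic for single values so it can be checked without PyTorch; `normalize_pixel`, `sigmoid`, and `logits_to_alpha` are hypothetical names, not part of BiRefNet.

```python
import math
from typing import List, Sequence

# ImageNet statistics, as passed to transforms.Normalize above.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Sequence[int]) -> List[float]:
    """Mimic ToTensor + Normalize for one 8-bit RGB pixel."""
    return [(c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD)]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function (avoids overflow for large |z|)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_alpha(logits: Sequence[float]) -> List[int]:
    """Mimic sigmoid + ToPILImage: logits -> 8-bit alpha values in 0..255."""
    # int() truncates toward zero, like a cast to uint8 for non-negative values
    return [int(sigmoid(z) * 255) for z in logits]


print(logits_to_alpha([-10.0, 0.0, 10.0]))  # [0, 127, 254]
```

A logit of 0 (a 50/50 prediction) lands at alpha 127, which is why soft edges such as hair show up as semi-transparent pixels in the extracted cutout.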

### 2. Use the inference endpoint locally:
> You may need to click *Deploy* and set up the endpoint yourself, which may incur costs.
```python
import base64
from io import BytesIO

import requests
from PIL import Image


YOUR_HF_TOKEN = 'xxx'
API_URL = "xxx"
headers = {
    "Authorization": "Bearer {}".format(YOUR_HF_TOKEN)
}

def base64_to_bytes(base64_string):
    # Remove the data URI prefix if present
    if "data:image" in base64_string:
        base64_string = base64_string.split(",")[1]

    # Decode the Base64 string into bytes
    image_bytes = base64.b64decode(base64_string)
    return image_bytes

def bytes_to_image(image_bytes):
    # Wrap the raw bytes in a file-like object and open it with Pillow (PIL)
    return Image.open(BytesIO(image_bytes))

def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    response.raise_for_status()  # fail loudly on HTTP errors instead of decoding an error body
    return response.json()

output = query({
    "inputs": "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg",
    "parameters": {}
})

# The endpoint returns the result image as a Base64 string
output_image = bytes_to_image(base64_to_bytes(output))
output_image  # displays the image in a notebook; use output_image.show() in a script
```
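
The endpoint snippet expects the response to be a Base64-encoded image, possibly wrapped in a data URI such as `data:image/png;base64,<payload>`. The stdlib-only sketch below isolates that decoding step and adds the reverse encoding, which can be handy for testing such helpers offline without calling the endpoint. `to_data_uri` and `from_data_uri` are hypothetical names, the `image/png` MIME type is an assumption, and `from_data_uri` uses a slightly stricter `startswith("data:")` check than the helper above.

```python
import base64


def to_data_uri(raw: bytes, mime: str = "image/png") -> str:
    """Encode raw bytes as a Base64 data URI (the MIME type is an assumption)."""
    return f"data:{mime};base64," + base64.b64encode(raw).decode("ascii")


def from_data_uri(payload: str) -> bytes:
    """Decode a Base64 string, stripping a 'data:...;base64,' prefix if present."""
    if payload.startswith("data:"):
        # Everything after the first comma is the Base64 body.
        payload = payload.split(",", 1)[1]
    return base64.b64decode(payload)


uri = to_data_uri(b"\x89PNG")  # first four bytes of the PNG signature
print(uri)                     # data:image/png;base64,iVBORw==
print(from_data_uri(uri))      # b'\x89PNG'
```

Splitting on the first comma only (`split(",", 1)`) keeps the function correct even if the prefix itself carries extra parameters.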

> This BiRefNet model for standard dichotomous image segmentation (DIS) is trained on **DIS-TR** and validated on **DIS-TEs and DIS-VD**.

## This repo holds the official model weights of "[<ins>Bilateral Reference for High-Resolution Dichotomous Image Segmentation</ins>](https://arxiv.org/pdf/2401.03407)" (_CAAI AIR 2024_).

This repo contains the weights of BiRefNet proposed in our paper, which achieved state-of-the-art (SOTA) performance on three tasks: DIS, HRSOD, and COD.

Go to my GitHub page for the BiRefNet code and the latest updates: https://github.com/ZhengPeng7/BiRefNet :)


#### Try our online demos for inference:

+ Online **image inference** on Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba?usp=drive_link)
+ **Online inference with a GUI on Hugging Face**, with adjustable resolutions: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo)
+ **Inference and evaluation** of your own weights: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MaEiBfJ4xIaZZn0DqKrhydHB8X97hNXl#scrollTo=DJ4meUYjia6S)

<img src="https://drive.google.com/thumbnail?id=12XmDhKtO1o2fEvBu4OE4ULVB2BK0ecWi&sz=w1080" />

## Acknowledgement:

+ Many thanks to @Freepik for their generous GPU support for training higher-resolution BiRefNet models and for more of my explorations.
+ Many thanks to @fal for their generous GPU support for training better general-purpose BiRefNet models.
+ Many thanks to @not-lain for his help in improving the deployment of our BiRefNet model on Hugging Face.

## Citation

```bibtex
@article{zheng2024birefnet,
  title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
  author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
  journal={CAAI Artificial Intelligence Research},
  volume={3},
  pages={9150038},
  year={2024}
}
```