example_inference.py
1.1 KB · 39 lines · python Raw
1 from skimage import io
2 import torch, os
3 from PIL import Image
4 from briarmbg import BriaRMBG
5 from utilities import preprocess_image, postprocess_image
6 from huggingface_hub import hf_hub_download
7
def example_inference(im_path=None, output_path="example_image_no_bg.png",
                      model_input_size=None):
    """Remove the background from an image using the BRIA RMBG-1.4 model.

    Loads the pretrained ``briaai/RMBG-1.4`` weights, predicts a mask for the
    input image, and saves a copy of the image with that mask applied as its
    alpha channel.

    Args:
        im_path: Path of the input image. Defaults to ``example_input.jpg``
            in the same directory as this script.
        output_path: Where the resulting image is saved. Defaults to
            ``example_image_no_bg.png`` relative to the current working
            directory (the original hard-coded location).
        model_input_size: Size handed to ``preprocess_image`` before
            inference. Defaults to ``[1024, 1024]``, the original value.
    """
    if im_path is None:
        im_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               "example_input.jpg")
    if model_input_size is None:
        # Fresh list per call: avoids sharing a mutable default argument.
        model_input_size = [1024, 1024]

    # Load the pretrained network directly. Constructing a bare BriaRMBG()
    # first (as before) only built a model that was immediately discarded.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
    net.to(device)
    net.eval()

    # prepare input
    orig_im = io.imread(im_path)
    # First two dimensions of the decoded array (height, width); used to
    # resize the predicted mask back to the original image size.
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size).to(device)

    # inference -- no gradients are needed, so skip building the autograd
    # graph (saves memory and time).
    with torch.no_grad():
        result = net(image)

    # post process
    # NOTE(review): result[0][0] presumably selects the primary output map for
    # the first (only) batch item -- confirm against BriaRMBG.forward.
    result_image = postprocess_image(result[0][0], orig_im_size)

    # save result
    pil_mask_im = Image.fromarray(result_image)
    # Open the original in a context manager so the file handle is closed;
    # .copy() loads the pixel data, so the copy stays valid after closing.
    with Image.open(im_path) as orig_image:
        no_bg_image = orig_image.copy()
    no_bg_image.putalpha(pil_mask_im)
    no_bg_image.save(output_path)
36
37
if __name__ == "__main__":
    # Script entry point: run the background-removal demo only when this
    # file is executed directly, not when it is imported as a module.
    example_inference()