# inference.py
"""Run BEN2 on a single image and save the extracted foreground.

When run as a script, it reads ``./image.png``, loads weights from
``./BEN2_Base.pth`` and writes the result to ``./foreground.png``.
"""
import BEN2
from PIL import Image
import torch


def main(
    input_path="./image.png",
    checkpoint_path="./BEN2_Base.pth",
    output_path="./foreground.png",
):
    """Load a BEN_Base model, run it on one image and save the output.

    Args:
        input_path: Path of the input image, opened with ``PIL.Image.open``.
        checkpoint_path: Weights file passed to ``model.loadcheckpoints``.
        output_path: Destination passed to the ``.save`` method of the object
            returned by ``model.inference``.
    """
    # Prefer the GPU when one is available; otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # eval() only switches train/inference-dependent layer behaviour; it does
    # not turn off autograd, which is why no_grad() wraps the forward pass.
    model = BEN2.BEN_Base().to(device).eval()
    model.loadcheckpoints(checkpoint_path)

    # The context manager closes the image's underlying file handle once
    # inference has finished; no_grad() avoids building an autograd graph.
    with Image.open(input_path) as image, torch.no_grad():
        foreground = model.inference(image)

    foreground.save(output_path)


if __name__ == "__main__":
    main()