# inference.py
"""Run the BEN2 ``BEN_Base`` model on one image and save the returned result.

Steps, in order:
1. Pick CUDA when PyTorch reports it is available, otherwise CPU.
2. Build ``BEN2.BEN_Base`` on that device, in eval mode.
3. Load weights from ``./BEN2_Base.pth`` via the model's ``loadcheckpoints``.
4. Run ``model.inference`` on ``./image.png``.
5. Save the returned image to ``./foreground.png``.

NOTE(review): the output is presumably the foreground with the background
removed (going by the variable and output names); confirm against BEN2 docs.
"""
import BEN2
from PIL import Image
import torch


# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

file = "./image.png"  # input image

# Build the model on the selected device and switch it to eval mode so layers
# such as dropout / batch-norm (if the model has any) use inference behavior.
model = BEN2.BEN_Base().to(device).eval()

# Load pretrained weights with BEN2's own helper (semantics defined in BEN2).
model.loadcheckpoints("./BEN2_Base.pth")

# Image.open is lazy and keeps the source file open; the context manager makes
# sure that handle is closed. The result is saved inside the block so nothing
# derived from the input is touched after the file is released.
with Image.open(file) as image:
    foreground = model.inference(image)
    foreground.save("./foreground.png")