onnx_run.py
| 1 | import onnxruntime |
| 2 | import numpy as np |
| 3 | from PIL import Image |
| 4 | import torchvision.transforms as transforms |
| 5 | import torch |
| 6 | import torch.nn.functional as F |
| 7 | |
| 8 | |
# ONNX Runtime session for the BEN2 model, created once when this module is
# loaded. The model path is resolved relative to the current working directory.
# CUDA is preferred with CPU as the fallback; providers that this onnxruntime
# build does not offer are dropped up front rather than passed through (ONNX
# Runtime warns about requested providers it does not have).
session = onnxruntime.InferenceSession(
    "./BEN2_Base.onnx",
    providers=[
        provider
        for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if provider in onnxruntime.get_available_providers()
    ],
)
| 10 | |
def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a raw model prediction and convert it to an 8-bit mask.

    Args:
        result_np: Model output with shape (H, W), (C, H, W) or (1, C, H, W).
            Missing leading dimensions are added before resizing; a batch
            dimension larger than 1 is not supported.
        im_size: Target size as ``[height, width]`` -- the order that
            ``F.interpolate`` uses for its ``size`` argument.

    Returns:
        A uint8 array min-max scaled to 0..255 (values are truncated, not
        rounded). All size-1 axes are squeezed away, so a single-channel
        prediction comes back as (height, width). A constant prediction has
        no range to scale by and yields all zeros.
    """
    # Cast to float32: model outputs may be float16/float64/int, and bilinear
    # interpolation is not implemented for every dtype on every torch build.
    result = torch.from_numpy(result_np).float()

    # Bilinear F.interpolate requires a 4-D (N, C, H, W) tensor.
    while result.dim() < 4:
        result = result.unsqueeze(0)

    result = torch.squeeze(
        F.interpolate(result, size=im_size, mode="bilinear", align_corners=False),
        0,
    )

    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # Flat (or NaN) prediction: avoid 0/0, whose uint8 cast is undefined.
        result = torch.zeros_like(result)

    im_array = (result * 255).permute(1, 2, 0).numpy().astype(np.uint8)
    return np.squeeze(im_array)
| 30 | |
def preprocess_image(image):
    """Convert a PIL image into the model's input array.

    The image is converted to RGB, resized to 1024x1024 and turned into a
    float tensor by ``transforms.ToTensor``, then given a leading batch axis.

    Args:
        image: A PIL image. It is not modified.

    Returns:
        A tuple ``(input_array, image, original_size)``: ``input_array`` is a
        NumPy array of shape (1, 3, 1024, 1024), ``image`` is the unmodified
        input image, and ``original_size`` is ``image.size``, i.e.
        ``(width, height)``.
    """
    original_size = image.size
    transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ])
    # ToTensor keeps the source channel count (4 for RGBA, 1 for L/P), so
    # force three channels. Presumably the ONNX model expects 3-channel RGB
    # input -- TODO confirm against session.get_inputs()[0].shape.
    img_tensor = transform(image.convert("RGB"))
    return img_tensor.unsqueeze(0).numpy(), image, original_size
| 41 | |
def run_inference(image):
    """Predict a foreground mask for *image* and attach it as an alpha channel.

    Args:
        image: A PIL image. It is left unmodified.

    Returns:
        A new PIL image: a copy of *image* with the predicted mask, at the
        image's original resolution, set as its alpha channel via
        ``Image.putalpha``.
    """
    input_data, original_image, (w, h) = preprocess_image(image)

    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_data})

    # F.interpolate takes (height, width) while PIL's .size is (width, height);
    # passing [h, w] yields a mask that already matches the image, so no
    # second resample is needed.
    alpha = postprocess_image(outputs[0], im_size=[h, w])
    mask = Image.fromarray(alpha)

    # putalpha modifies the image in place; work on a copy so the caller's
    # image is not altered.
    composited = original_image.copy()
    composited.putalpha(mask)
    return composited
| 60 | |
# Example usage. Guarded so that importing this module (for example to reuse
# run_inference) does not read image.png and write output.png.
if __name__ == "__main__":
    image_path = "image.png"
    output_path = "output.png"

    # The context manager closes the underlying file once we are done.
    with Image.open(image_path) as image:
        result_image = run_inference(image)
        result_image.save(output_path)
| 70 | |