# onnx_run.py
# 1.6 KB · 70 lines · Python
1 import onnxruntime
2 import numpy as np
3 from PIL import Image
4 import torchvision.transforms as transforms
5 import torch
6 import torch.nn.functional as F
7
8
# One ONNX Runtime session for the BEN2 model, built once when the module is
# imported and shared by run_inference(). The model path is relative, so it is
# resolved against the current working directory. Providers are listed in
# order of preference: CUDA first, then the CPU fallback.
session = onnxruntime.InferenceSession(
    "./BEN2_Base.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

11 def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
12
13 result = torch.from_numpy(result_np)
14
15
16 if len(result.shape) == 3:
17 result = result.unsqueeze(0)
18
19
20 result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
21
22
23 ma = torch.max(result)
24 mi = torch.min(result)
25 result = (result - mi) / (ma - mi)
26
27 im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
28 im_array = np.squeeze(im_array)
29 return im_array
30
31 def preprocess_image(image):
32 original_size = image.size
33 transform = transforms.Compose([
34 transforms.Resize((1024, 1024)),
35 transforms.ToTensor(),
36 ])
37 img_tensor = transform(image)
38
39 img_tensor = img_tensor.unsqueeze(0)
40 return img_tensor.numpy(), image, original_size
41
42 def run_inference(image):
43
44 input_data, original_image, (w, h) = preprocess_image(image)
45
46 input_name = session.get_inputs()[0].name
47
48 outputs = session.run(None, {input_name: input_data})
49
50
51 alpha = postprocess_image(outputs[0], im_size=[w, h])
52
53
54 mask = Image.fromarray(alpha)
55 mask = mask.resize((w, h))
56
57
58 original_image.putalpha(mask)
59 return original_image
60
61 # Example usage
62 image_path = "image.png"
63 output_path = "output.png"
64
65
66 image = Image.open(image_path)
67
68 result_image = run_inference(image)
69 result_image.save(output_path)
70