MyPipe.py
2.8 KB · 77 lines · python Raw
1 import torch, os
2 import torch.nn.functional as F
3 from torchvision.transforms.functional import normalize
4 import numpy as np
5 from transformers import Pipeline
6 from transformers.image_utils import load_image
7 from skimage import io
8 from PIL import Image
9
class RMBGPipe(Pipeline):
    """Background-removal pipeline: takes an image, returns it with an alpha mask.

    Call-time keyword arguments recognised by ``_sanitize_parameters``:

    * ``model_input_size`` -- (height, width) the image is resized to before
      it is given to the model; forwarded to ``preprocess``.
    * ``return_mask`` -- when truthy, ``postprocess`` returns the mask itself
      instead of the input image with the mask applied as its alpha channel.
    """

    def __init__(self, **kwargs):
        """Initialise the base pipeline, choose a device and switch to eval mode.

        When the caller does not pass ``device``, the pipeline uses ``cuda:0``
        if CUDA is available and the CPU otherwise.  An explicit ``device``
        argument is no longer overwritten.
        """
        super().__init__(**kwargs)
        if kwargs.get("device") is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # NOTE(review): with an explicit ``device`` this relies on
        # Pipeline.__init__ having stored it in ``self.device`` -- confirm
        # against the installed transformers version.
        self.model.to(self.device)
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split call kwargs into (preprocess, forward, postprocess) kwargs."""
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "model_input_size" in kwargs:
            preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
        if "return_mask" in kwargs:
            postprocess_kwargs["return_mask"] = kwargs["return_mask"]
        # The forward step takes no extra parameters.
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, input_image, model_input_size=(1024, 1024)):
        """Load the image and build the model input tensor.

        Args:
            input_image: anything accepted by ``load_image``.
            model_input_size: (height, width) passed to ``preprocess_image``.

        Returns:
            A dict with the model input (``preprocessed_image``), the original
            (height, width) (``orig_im_size``) and the loaded image
            (``input_image``) so ``postprocess`` does not have to load it again.
        """
        loaded_image = load_image(input_image)
        orig_im = np.array(loaded_image)
        orig_im_size = orig_im.shape[0:2]
        preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
        return {
            "preprocessed_image": preprocessed_image,
            "orig_im_size": orig_im_size,
            # Keep the already-loaded image: re-loading in postprocess would
            # fetch/decode a URL or file path a second time.
            "input_image": loaded_image,
        }

    def _forward(self, inputs):
        """Run the model and store its raw output under ``result``."""
        inputs["result"] = self.model(inputs.pop("preprocessed_image"))
        return inputs

    def postprocess(self, inputs, return_mask: bool = False):
        """Turn the model output into a mask or a background-free image.

        Args:
            inputs: the dict produced by ``_forward``.
            return_mask: when truthy, return only the mask image.

        Returns:
            The mask as a PIL image, or a copy of the input image with the
            mask set as its alpha channel.
        """
        result = inputs.pop("result")
        orig_im_size = inputs.pop("orig_im_size")
        input_image = inputs.pop("input_image")
        # result[0][0] is used as the mask prediction -- presumably the first
        # tensor of the model's first output group; verify against the model.
        mask = Image.fromarray(self.postprocess_image(result[0][0], orig_im_size))
        if return_mask:
            return mask
        if not isinstance(input_image, Image.Image):
            # Only reached when the dict was not built by ``preprocess``.
            input_image = load_image(input_image)
        no_bg_image = input_image.copy()
        no_bg_image.putalpha(mask)
        return no_bg_image

    # utilities functions
    def preprocess_image(self, im: np.ndarray, model_input_size=(1024, 1024)) -> torch.Tensor:
        """Convert an H x W (x C) array into a normalised 1 x C x H' x W' tensor.

        The image is resized bilinearly to ``model_input_size``, scaled by
        1/255 and shifted by -0.5 per channel (std 1.0).  Single-channel input
        is replicated to three channels, because the normalisation below is
        defined for three channels.
        """
        if im.ndim < 3:
            im = im[:, :, np.newaxis]
        if im.shape[2] == 1:
            im = np.repeat(im, 3, axis=2)
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(im_tensor.unsqueeze(0), size=model_input_size, mode="bilinear")
        image = torch.divide(im_tensor, 255.0)
        return normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
        """Resize a prediction to ``im_size`` and min-max scale it to uint8.

        Args:
            result: 4-D tensor (batch of one) as required by the bilinear
                ``F.interpolate`` call and the ``squeeze(0)`` that follows.
            im_size: target (height, width).

        Returns:
            ``uint8`` array in [0, 255]; H x W for a single-channel
            prediction, H x W x C otherwise.
        """
        result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        value_range = ma - mi
        if value_range > 0:
            result = (result - mi) / value_range
        else:
            # A flat prediction would divide by zero and push NaN into the
            # uint8 cast; emit an all-zero mask instead.
            result = torch.zeros_like(result)
        im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
        # Drop only the channel axis; a bare np.squeeze would also collapse a
        # height or width of 1 and yield an array PIL cannot use as a mask.
        if im_array.shape[2] == 1:
            im_array = im_array[:, :, 0]
        return im_array
77