handler.py
| 1 | from typing import Dict, List, Any |
| 2 | from PIL import Image |
| 3 | from io import BytesIO |
| 4 | from transformers import AutoModelForSemanticSegmentation, AutoFeatureExtractor |
| 5 | import base64 |
| 6 | import torch |
| 7 | from torch import nn |
| 8 | |
class EndpointHandler():
    """Inference handler that runs a semantic-segmentation model on one image.

    The model and its feature extractor are loaded once from ``path`` when the
    handler is constructed. Each call decodes a base64 image, runs the model,
    and returns a per-pixel class-id map at the decoded image's resolution.
    """

    def __init__(self, path="."):
        """Load the segmentation model and feature extractor stored at ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the feature extractor.
        """
        # Prefer the GPU when one is visible; fall back to CPU otherwise.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # eval() switches the model to inference mode (e.g. disables dropout).
        self.model = AutoModelForSemanticSegmentation.from_pretrained(path).to(self.device).eval()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[List[int]]:
        """Segment one base64-encoded image.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                when that key is present, otherwise from ``data`` itself.
                That value may be either the base64-encoded image itself
                (``str`` or ``bytes``) or a mapping whose ``"image"`` entry
                holds it. ``data`` is not modified.

        Returns:
            A list of ``height`` rows, each a list of ``width`` ints, where
            ``height``/``width`` are those of the decoded image and each int
            is the index of the highest-scoring class (argmax over the
            model's logits) for that pixel.

        Raises:
            KeyError: If the inputs are a mapping without an ``"image"`` entry.
            binascii.Error: If the payload cannot be base64-decoded.
            PIL.UnidentifiedImageError: If the decoded bytes are not an image.
        """
        # Read without mutating the caller's payload (a pop() here would).
        inputs = data.get("inputs", data)

        if isinstance(inputs, (str, bytes)):
            # The base64 image was sent directly as the inputs value.
            b64_image = inputs
        else:
            try:
                b64_image = inputs["image"]
            except KeyError:
                raise KeyError(
                    "request inputs must contain an 'image' field holding a "
                    "base64-encoded image"
                ) from None

        # Decode base64 -> PIL and normalize to 3-channel RGB so grayscale,
        # palette or RGBA uploads reach the feature extractor as RGB images.
        # convert() returns a fully loaded copy, so the source can be closed.
        with Image.open(BytesIO(base64.b64decode(b64_image))) as raw_image:
            image = raw_image.convert("RGB")

        encoding = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = encoding["pixel_values"].to(self.device)

        with torch.no_grad():
            logits = self.model(pixel_values=pixel_values).logits
            # The logits may not be at the input image's resolution, so resize
            # them to it. PIL's size is (width, height) while interpolate
            # expects (height, width), hence the reversal.
            upsampled_logits = nn.functional.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            # A single image is encoded per request, so take batch element 0.
            pred_seg = upsampled_logits.argmax(dim=1)[0]

        return pred_seg.cpu().tolist()
| 40 | |