File size: 2,081 Bytes
0c4adde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Dict, List, Any
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

from PIL import Image
import torch
import base64
from base64 import b64encode
import requests
import json
import io

# Decode a base64-encoded image payload into a PIL image.
def stringToRGB(base64_string):
    """Decode *base64_string* and return the image as an RGB ``PIL.Image``."""
    decoded_bytes = base64.b64decode(str(base64_string))
    buffer = io.BytesIO(decoded_bytes)
    # Force RGB so downstream feature extraction sees 3 channels regardless
    # of the source format (grayscale, RGBA, palette, ...).
    return Image.open(buffer).convert('RGB')


def predict_caption(image_str, max_token = 32):
    """Generate a caption for a single base64-encoded image.

    Relies on the module-level ``model``, ``feature_extractor``,
    ``tokenizer`` and ``device`` populated at startup.

    Args:
        image_str: base64-encoded image data.
        max_token: maximum caption length in tokens.

    Returns:
        The generated caption as a stripped string.
    """
    generation_kwargs = {"max_length": max_token, "num_beams": 4}

    rgb_image = stringToRGB(image_str)

    # Batch of one: the extractor expects a list of images.
    pixel_values = feature_extractor(images=[rgb_image], return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **generation_kwargs)

    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0].strip()


class EndpointHandler():
    def __init__(self, path=""):
        """Load the captioning model, feature extractor and tokenizer once.

        Args:
            path: model repository / local directory to load from.

        NOTE: the module-level ``predict_caption`` reads ``model``,
        ``feature_extractor``, ``tokenizer`` and ``device`` as globals, so
        they must be published globally here.  The previous version assigned
        them to locals, which made every inference call fail with NameError.
        """
        global model, feature_extractor, tokenizer, device

        model = VisionEncoderDecoderModel.from_pretrained(path)
        feature_extractor = ViTFeatureExtractor.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one inference request.

        Args:
            data: request payload with keys
                ``data``      -- base64-encoded image string (required)
                ``max_token`` -- maximum caption length in tokens (default 32)

        Returns:
            ``{"caption": <generated caption>}``

        Raises:
            ValueError: if the payload contains no ``data`` field.
        """
        max_token = data.pop("max_token", 32)
        img_str = data.pop("data", None)
        if img_str is None:
            # Fail loudly with a clear message instead of a cryptic
            # base64-decoding error deep inside predict_caption.
            raise ValueError('Request payload must include a "data" field '
                             'containing a base64-encoded image.')

        caption = predict_caption(img_str, max_token=max_token)
        return {"caption": f"{caption}"}