from typing import Any, Dict

import base64
import io

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

def stringToRGB(base64_string):
    # Decode a base64-encoded image string into a PIL RGB image.
    imgdata = base64.b64decode(str(base64_string))
    img = Image.open(io.BytesIO(imgdata)).convert("RGB")
    return img

def predict_caption(model, feature_extractor, tokenizer, device, image_str, max_token=32):
    # Generate a caption for a single base64-encoded image using beam search.
    gen_kwargs = {"max_length": max_token, "num_beams": 4}
    image = stringToRGB(image_str)
    # Preprocess the image into the pixel values expected by the vision encoder.
    pixel_values = feature_extractor(images=[image], return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    # Generate token ids and decode them into a caption string.
    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return preds[0].strip()

class EndpointHandler:
    def __init__(self, path=""):
        # Preload everything needed at inference: model, image preprocessor, tokenizer.
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            data (:obj:`str`): base64-encoded image to caption
            max_token (:obj:`int`, optional): maximum caption length, defaults to 32
        Return:
            A :obj:`dict` containing the generated caption; it will be serialized and returned
        """
        max_token = data.pop("max_token", 32)
        img_str = data.pop("data", None)
        if img_str is None:
            raise ValueError("No image provided under the 'data' key.")
        caption = predict_caption(
            self.model, self.feature_extractor, self.tokenizer, self.device,
            img_str, max_token=max_token,
        )
        return {"caption": caption}
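

# ---------------------------------------------------------------------------
# Usage sketch (not part of the deployed handler): shows how a client payload
# could be built and the handler invoked locally. The model directory
# "./image-captioning-model" and the image file "example.jpg" are placeholder
# names used for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="./image-captioning-model")
    with open("example.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")
    result = handler({"data": img_b64, "max_token": 32})
    print(result)  # prints {"caption": "..."}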