from typing import Any, Dict

import base64
import io

import torch
from PIL import Image
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel


def stringToRGB(base64_string):
    """Decode a base64-encoded image string into an RGB PIL image."""
    imgdata = base64.b64decode(str(base64_string))
    img = Image.open(io.BytesIO(imgdata)).convert("RGB")
    return img
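

# For reference, a compatible base64 payload can be produced from an image file
# along these lines ("photo.jpg" is an illustrative placeholder, not a file this
# repo ships):
#
#     with open("photo.jpg", "rb") as f:
#         image_str = base64.b64encode(f.read()).decode("utf-8")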


def predict_caption(image_str, model, feature_extractor, tokenizer, device, max_token=32):
    """Caption a single base64-encoded image with beam search.

    The model components are passed in explicitly; the handler below stores
    them on the instance and forwards them here.
    """
    num_beams = 4
    gen_kwargs = {"max_length": max_token, "num_beams": num_beams}

    images = [stringToRGB(image_str)]

    # Preprocess into model-ready pixel values and move them to the model's device.
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds[0]


class EndpointHandler:
    def __init__(self, path=""):
        # Load the captioning model and its preprocessing components once, and
        # keep them on the instance so __call__ can reach them.
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            data (:obj:`str`): a base64-encoded image
            max_token (:obj:`int`, optional): maximum caption length, defaults to 32
        Return:
            A :obj:`dict` with the generated caption; it will be serialized and returned
        """
        max_token = data.pop("max_token", 32)
        img_str = data.pop("data", None)

        caption = predict_caption(
            img_str,
            self.model,
            self.feature_extractor,
            self.tokenizer,
            self.device,
            max_token=max_token,
        )
        return {"caption": caption}