from typing import Dict, List, Any
from PIL import Image
import torch
import base64
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel

# Run on the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        # Load the CLIP model and processor once at startup (the model id is
        # hardcoded rather than read from `path`) and move the model to `device`.
        self.model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").to(device)
        self.processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[float]:
        inputs = data.pop("inputs", data)

        if "image" in inputs:
            # Decode the base64-encoded image to PIL and return its CLIP image embedding.
            image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
            processed = self.processor(images=image, text=None, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                image_embeds = self.model.get_image_features(pixel_values=processed["pixel_values"])
            return image_embeds[0].tolist()

        if "text" in inputs:
            # Tokenize the text and return its CLIP text embedding.
            text = inputs["text"]
            processed = self.processor(images=None, text=text, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                text_embeds = self.model.get_text_features(
                    input_ids=processed["input_ids"], attention_mask=processed["attention_mask"]
                )
            return text_embeds[0].tolist()

        raise ValueError("No 'image' or 'text' provided")
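

# Minimal local usage sketch (illustrative only, not part of the Inference Endpoints
# contract): instantiating the handler downloads the CLIP weights, and the payload
# below wraps either "text" or a base64-encoded "image" under "inputs", matching
# what __call__ expects. The sample text is an arbitrary example.
if __name__ == "__main__":
    handler = EndpointHandler()
    embedding = handler({"inputs": {"text": "a photo of a cat"}})
    print(len(embedding))  # dimensionality of the returned CLIP text embedding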