moooji's picture
Update handler.py
ed9455f
raw
history blame
971 Bytes
from typing import Dict, List, Any
from PIL import Image
import torch
import base64
from io import BytesIO
from transformers import AutoImageProcessor, Swinv2Model
# Compute device shared by the handler: CUDA when torch reports a usable GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Inference Endpoints handler that embeds one base64-encoded image with SwinV2.

    The backbone and its image processor are always loaded from the Hub
    checkpoint named by ``MODEL_ID``; the ``path`` constructor argument is
    accepted for the handler interface but is not used.
    """

    # Single source of truth for the checkpoint so model and processor cannot drift apart.
    MODEL_ID = "microsoft/swinv2-large-patch4-window12to24-192to384-22kto1k-ft"

    def __init__(self, path=""):
        """Load the SwinV2 backbone onto the module-level ``device`` plus its image processor.

        Args:
            path: Unused; the checkpoint is always loaded from ``MODEL_ID``.
        """
        self.model = Swinv2Model.from_pretrained(self.MODEL_ID).to(device)
        self.processor = AutoImageProcessor.from_pretrained(self.MODEL_ID)

    def __call__(self, data: Any) -> List[float]:
        """Return the pooled SwinV2 embedding of a single image.

        Args:
            data: Request payload. If it contains an ``"inputs"`` key, that value
                is popped from ``data`` and used; otherwise ``data`` itself is used.
                The resulting mapping must hold a base64-encoded image file under
                ``"image"``.

        Returns:
            The pooled feature vector of the image as a flat list of floats.
        """
        payload = data.pop("inputs", data)
        image_bytes = base64.b64decode(payload["image"])
        # Normalise to 3-channel RGB so RGBA / greyscale / palette uploads do not
        # reach the processor with an unexpected number of channels.
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        model_inputs = self.processor(image, return_tensors="pt").to(device)
        with torch.no_grad():
            # Use the model loaded in __init__ (a bare ``model`` name is undefined here).
            outputs = self.model(**model_inputs)

        # Per the transformers Swinv2Model docs, ``last_hidden_state`` is indexed
        # [batch, sequence, hidden], so indexing it with 2 on a one-image batch is
        # out of range. ``pooler_output`` is the average-pooled [batch, hidden]
        # embedding; row 0 is this image, giving the annotated List[float].
        # NOTE(review): confirm the pooled embedding is the intended output.
        return outputs.pooler_output[0].tolist()