moooji committed
Commit 3cd62bd · 1 Parent(s): dd3fd4d

Create handler.py

Files changed (1): handler.py (+38 -0)
handler.py ADDED
@@ -0,0 +1,38 @@
+ from typing import Dict, List, Any
+ from PIL import Image
+ import torch
+ import base64
+ from io import BytesIO
+ from transformers import CLIPProcessor, CLIPModel
+
+ # Prefer GPU when one is available.
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         self.model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").to(device)
+         self.processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
+
+     def __call__(self, data: Dict[str, Any]) -> List[float]:
+         inputs = data.pop("inputs", data)  # unwrap the {"inputs": ...} request payload
+
+         if "image" in inputs:
+             # decode base64 image to PIL
+             image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
+             inputs = self.processor(images=image, text=None, return_tensors="pt", padding=True).to(device)
+
+             image_embeds = self.model.get_image_features(
+                 pixel_values=inputs["pixel_values"]
+             )
+
+             return image_embeds[0].tolist()
+         if "text" in inputs:
+             text = inputs["text"]
+             inputs = self.processor(images=None, text=text, return_tensors="pt", padding=True).to(device)
+
+             text_embeds = self.model.get_text_features(
+                 input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
+             )
+
+             return text_embeds[0].tolist()
+
+         raise ValueError("No 'image' or 'text' provided")
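For reference, a quick local smoke test of the handler might look like the sketch below. The {"inputs": {...}} payload shape matches the handler's contract above; the image path "cat.jpg", the sample text, and the 768-dimension note are illustrative assumptions, not part of this commit.

import base64

from handler import EndpointHandler

handler = EndpointHandler()

# Text embedding: payload follows the {"inputs": {...}} shape the handler expects.
text_embedding = handler({"inputs": {"text": "a photo of a cat"}})
print(len(text_embedding))  # projection dim of CLIP ViT-L/14 (768, assuming the default config)

# Image embedding: the handler expects a base64-encoded image string.
with open("cat.jpg", "rb") as f:  # "cat.jpg" is a placeholder path
    encoded = base64.b64encode(f.read()).decode("utf-8")
image_embedding = handler({"inputs": {"image": encoded}})
print(len(image_embedding))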