from typing import Dict, List, Any
from PIL import Image
import torch
import base64
import os
from io import BytesIO
import json
import sys

CODE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "code")
sys.path.append(CODE_PATH)

from clip.model import CLIP
from clip.clip import _transform, tokenize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def preprocess_image(image_base64, transformer):
    """Convert base64 encoded sketch to tensor."""
    image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
    image = transformer(image).unsqueeze(0).to(device)
    return image


def preprocess_text(text):
    """Tokenize text query."""
    return tokenize([str(text)])[0].unsqueeze(0).to(device)


def get_fused_embedding(sketch_base64, text, model, transformer):
    """Fuse sketch and text features into a single embedding."""
    with torch.no_grad():
        sketch_tensor = preprocess_image(sketch_base64, transformer)
        text_tensor = preprocess_text(text)

        sketch_feature = model.encode_sketch(sketch_tensor)
        text_feature = model.encode_text(text_tensor)

        # L2-normalize each modality before fusing them into a joint embedding.
        sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)

        fused_embedding = model.feature_fuse(sketch_feature, text_feature)
    return fused_embedding.cpu().numpy().tolist()


def get_image_embedding(image_base64, model, transformer):
    """Encode a base64 encoded image into a normalized embedding."""
    image_tensor = preprocess_image(image_base64, transformer)
    with torch.no_grad():
        image_feature = model.encode_image(image_tensor)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
    return image_feature.cpu().numpy().tolist()


def get_text_embedding(text, model):
    """Encode a text query into a normalized embedding."""
    text_tensor = preprocess_text(text)
    with torch.no_grad():
        text_feature = model.encode_text(text_tensor)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
    return text_feature.cpu().numpy().tolist()


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the pipeline by loading the model.

        Args:
            path (str): Path to the directory containing model weights and config.
        """
        model_config_file = os.path.join(path, "code/training/model_configs/ViT-B-16.json")
        with open(model_config_file, "r") as f:
            model_info = json.load(f)

        model_file = os.path.join(path, "model/tsbir_model_final.pt")
        self.model = CLIP(**model_info)
        checkpoint = torch.load(model_file, map_location=device)

        # Strip the "module." prefix left by DataParallel/DistributedDataParallel, if present.
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith("module"):
            sd = {k[len("module."):]: v for k, v in sd.items()}
        self.model.load_state_dict(sd, strict=False)
        self.model = self.model.to(device).eval()

        # Preprocessing transform matching the model's input resolution.
        self.transform = _transform(self.model.visual.input_resolution, is_train=False)

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """
        Process the request and return an embedding.

        Args:
            data (dict): Request payload. Supported inputs are 'sketch' (base64) together
                with 'text' (str), or 'image' (base64) alone, or 'text' (str) alone.

        Returns:
            dict: {"embedding": [float, float, ...]} on success, or {"error": "..."} on failure.
        """
        inputs = data.pop("inputs", data)

        # Fused sketch + text embedding
        if len(inputs) == 2 and "sketch" in inputs and "text" in inputs:
            sketch_base64 = inputs.get("sketch", "")
            text_query = inputs.get("text", "")
            if not sketch_base64 or not text_query:
                return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
            fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
            return {"embedding": fused_embedding}

        # Image-only embedding
        elif len(inputs) == 1 and "image" in inputs:
            image_base64 = inputs.get("image", "")
            if not image_base64:
                return {"error": "'image' (base64) is a required input."}
            embedding = get_image_embedding(image_base64, self.model, self.transform)
            return {"embedding": embedding}

        # Text-only embedding
        elif len(inputs) == 1 and "text" in inputs:
            text_query = inputs.get("text", "")
            if not text_query:
                return {"error": "'text' is a required input."}
            embedding = get_text_embedding(text_query, self.model)
            return {"embedding": embedding}

        else:
            return {"error": "Invalid request. Provide 'sketch' and 'text', 'image' only, or 'text' only."}
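

if __name__ == "__main__":
    # Minimal local smoke test: a hedged usage sketch, not part of the serving contract.
    # Assumes the weights/config paths used in EndpointHandler.__init__ exist relative to
    # the current directory, and "sketch.png" is a placeholder path for any local sketch.
    handler = EndpointHandler(path=".")

    with open("sketch.png", "rb") as f:
        sketch_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Fused sketch + text query.
    fused = handler({"inputs": {"sketch": sketch_b64, "text": "a dog chasing a ball"}})
    print("fused:", "embedding" in fused)

    # Text-only query.
    text_only = handler({"inputs": {"text": "a dog chasing a ball"}})
    print("text-only:", "embedding" in text_only)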