# tsbir/handler.py
from typing import Dict, List, Any
from PIL import Image
import torch
import base64
import os
from io import BytesIO
import json
import sys

# Make the bundled "code/" directory (CLIP model definition and utilities) importable.
CODE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "code")
sys.path.append(CODE_PATH)

from clip.model import CLIP
from clip.clip import _transform, tokenize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def preprocess_image(image_base64, transformer):
"""Convert base64 encoded sketch to tensor."""
image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
image = transformer(image).unsqueeze(0).to(device)
    return image


def preprocess_text(text):
"""Tokenize text query."""
    return tokenize([str(text)])[0].unsqueeze(0).to(device)


def get_fused_embedding(sketch_base64, text, model, transformer):
"""Fuse sketch and text features into a single embedding."""
with torch.no_grad():
sketch_tensor = preprocess_image(sketch_base64, transformer)
text_tensor = preprocess_text(text)
sketch_feature = model.encode_sketch(sketch_tensor)
text_feature = model.encode_text(text_tensor)
sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
fused_embedding = model.feature_fuse(sketch_feature, text_feature)
    return fused_embedding.cpu().numpy().tolist()


def get_image_embedding(image_base64, model, transformer):
    """Encode a base64 image into a normalized image embedding."""
image_tensor = preprocess_image(image_base64, transformer)
with torch.no_grad():
image_feature = model.encode_image(image_tensor)
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
    return image_feature.cpu().numpy().tolist()


def get_text_embedding(text, model):
    """Encode a text query into a normalized text embedding."""
text_tensor = preprocess_text(text)
with torch.no_grad():
text_feature = model.encode_text(text_tensor)
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
return text_feature.cpu().numpy().tolist()
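

# ---------------------------------------------------------------------------
# Example (a sketch, not used by the handler): every helper above returns an
# L2-normalized embedding as a nested list of shape [1, D], so the dot product
# between two embeddings equals their cosine similarity. The names below
# (rank_by_similarity, query_emb, gallery_embs) are illustrative only.
# ---------------------------------------------------------------------------
def rank_by_similarity(query_emb, gallery_embs):
    """Rank gallery embeddings by cosine similarity to a query embedding."""
    q = query_emb[0]  # unwrap the [1, D] nesting
    scores = [sum(qi * gi for qi, gi in zip(q, g[0])) for g in gallery_embs]
    return sorted(range(len(scores)), key=lambda i: -scores[i])
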
class EndpointHandler:
def __init__(self, path: str = ""):
"""
Initialize the pipeline by loading the model.
Args:
path (str): Path to the directory containing model weights and config.
"""
model_config_file = os.path.join(path, "code/training/model_configs/ViT-B-16.json")
with open(model_config_file, "r") as f:
model_info = json.load(f)
model_file = os.path.join(path, "model/tsbir_model_final.pt")
self.model = CLIP(**model_info)
checkpoint = torch.load(model_file, map_location=device)
sd = checkpoint["state_dict"]
        # Strip the "module." prefix left by DataParallel-style checkpoints.
        if next(iter(sd)).startswith("module."):
            sd = {k[len("module."):]: v for k, v in sd.items()}
        self.model.load_state_dict(sd, strict=False)
self.model = self.model.to(device).eval()
# Preprocessing
        self.transform = _transform(self.model.visual.input_resolution, is_train=False)

    def __call__(self, data: Any) -> Dict[str, List[float]]:
"""
Process the request and return the fused embedding.
Args:
data (dict): Includes 'sketch' (base64) and 'text' (str) inputs, or 'image' (base64)
Returns:
dict: {"embedding": [float, float, ...]}
"""
inputs = data.pop("inputs", data)
# text-sketch embedding
if len(inputs) == 2 and "sketch" in inputs and "text" in inputs:
sketch_base64 = inputs.get("sketch", "")
text_query = inputs.get("text", "")
if not sketch_base64 or not text_query:
return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
# Generate Fused Embedding
fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
return {"embedding": fused_embedding}
# image-only embedding
elif len(inputs) == 1 and "image" in inputs:
image_base64 = inputs.get("image", "")
if not image_base64:
return {"error": "Image 'image' (base64) is required input."}
embedding = get_image_embedding(image_base64, self.model, self.transform)
return {"embedding": embedding}
# text-only embedding
elif len(inputs) == 1 and "text" in inputs:
text_query = inputs.get("text", "")
if not text_query:
return {"error": "Text 'text' is required input."}
embedding = get_text_embedding(text_query, self.model)
return {"embedding": embedding}
else:
return {"error": "Invalid request."}