# tsbir/inference.py
import torch
from PIL import Image
import base64
from io import BytesIO
import json
import sys

sys.path.append("code")

from clip.model import CLIP
from clip.clip import _transform, tokenize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "model/tsbir_model_final.pt"
CONFIG_PATH = "code/training/model_configs/ViT-B-16.json"

def load_model():
    """Load the model only once."""
    global model
    if "model" not in globals():
        with open(CONFIG_PATH, 'r') as f:
            model_info = json.load(f)
        model = CLIP(**model_info)

        checkpoint = torch.load(MODEL_PATH, map_location=device)
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith('module'):
            sd = {k[len('module.'):]: v for k, v in sd.items()}
        model.load_state_dict(sd, strict=False)
        model = model.to(device).eval()

        # Initialize transformer
        global transformer
        transformer = _transform(model.visual.input_resolution, is_train=False)
        print("Model loaded successfully.")

# Preprocessing Functions
def preprocess_image(image_base64):
    """Convert base64 encoded image to tensor."""
    image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
    image = transformer(image).unsqueeze(0).to(device)
    return image

def preprocess_text(text):
    """Tokenize text query."""
    return tokenize([str(text)])[0].unsqueeze(0).to(device)

def get_fused_embedding(image_base64, text):
    """Fuse sketch and text features into a single embedding."""
    with torch.no_grad():
        # Preprocess Inputs
        image_tensor = preprocess_image(image_base64)
        text_tensor = preprocess_text(text)

        # Extract Features
        sketch_feature = model.encode_sketch(image_tensor)
        text_feature = model.encode_text(text_tensor)

        # Normalize Features
        sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)

        # Fuse Features
        fused_embedding = model.feature_fuse(sketch_feature, text_feature)

    return fused_embedding.cpu().numpy().tolist()

# Hugging Face Inference API Entry Point
def infer(inputs):
    """
    Inference API entry point.

    Inputs:
        - 'image': Base64 encoded sketch image.
        - 'text': Text query.
    """
    load_model()  # Ensure the model is loaded once

    image_base64 = inputs.get("image", "")
    text_query = inputs.get("text", "")
    if not image_base64 or not text_query:
        return {"error": "Both 'image' (base64) and 'text' are required inputs."}

    # Generate Fused Embedding
    fused_embedding = get_fused_embedding(image_base64, text_query)
    return {"fused_embedding": fused_embedding}