import base64
import json
import os
import sys
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image

# Make the bundled "code/" directory importable so the custom CLIP fork resolves.
CODE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "code")
sys.path.append(CODE_PATH)
from clip.model import CLIP
from clip.clip import _transform, tokenize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def preprocess_image(image_base64, transformer):
    """Decode a base64-encoded image (sketch or photo) into a batched tensor on the target device."""
    image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
    image = transformer(image).unsqueeze(0).to(device)
    return image

def preprocess_text(text):
    """Tokenize a text query into a batched token tensor on the target device."""
    return tokenize([str(text)]).to(device)
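
# A minimal client-side sketch (assumption: the path points at any local image;
# "sketch.png" elsewhere in this file is illustrative). It shows how the base64
# payloads consumed by preprocess_image are typically produced. Not part of the
# original handler API.
def encode_image_file(path):
    """Read an image file and return its contents as a base64 string."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")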

def get_fused_embedding(sketch_base64, text, model, transformer):
    """Fuse sketch and text features into a single embedding."""
    with torch.no_grad():
        sketch_tensor = preprocess_image(sketch_base64, transformer)
        text_tensor = preprocess_text(text)

        sketch_feature = model.encode_sketch(sketch_tensor)
        text_feature = model.encode_text(text_tensor)
        
        # L2-normalize both modalities so they contribute on the same scale to the fusion.
        sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)

        fused_embedding = model.feature_fuse(sketch_feature, text_feature)
    return fused_embedding.cpu().numpy().tolist()

def get_image_embedding(image_base64, model, transformer):
    """Encode a base64 image into an L2-normalized image embedding."""
    image_tensor = preprocess_image(image_base64, transformer)
    with torch.no_grad():
        image_feature = model.encode_image(image_tensor)
        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
    return image_feature.cpu().numpy().tolist()

def get_text_embedding(text, model):
    """Encode a text query into an L2-normalized text embedding."""
    text_tensor = preprocess_text(text)
    with torch.no_grad():
        text_feature = model.encode_text(text_tensor)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
    return text_feature.cpu().numpy().tolist()
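
# A minimal retrieval sketch (assumptions: the query comes from get_fused_embedding
# or get_text_embedding, the gallery rows from get_image_embedding). Because every
# embedding above is L2-normalized, cosine similarity reduces to a dot product.
# Illustrative only; not part of the original handler API.
def rank_gallery(query_embedding, gallery_embeddings):
    """Return gallery indices sorted by descending cosine similarity to the query."""
    import numpy as np  # local import keeps this optional helper self-contained
    q = np.asarray(query_embedding).reshape(-1)                              # (d,)
    g = np.asarray(gallery_embeddings).reshape(len(gallery_embeddings), -1)  # (n, d)
    scores = g @ q  # dot product equals cosine similarity on unit-norm vectors
    return np.argsort(-scores).tolist()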

class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the pipeline by loading the model.
        Args:
            path (str): Path to the directory containing model weights and config.
        """
        model_config_file = os.path.join(path, "code/training/model_configs/ViT-B-16.json")
        with open(model_config_file, "r") as f:
            model_info = json.load(f)
        
        model_file = os.path.join(path, "model/tsbir_model_final.pt")
        self.model = CLIP(**model_info)
        checkpoint = torch.load(model_file, map_location=device)

        sd = checkpoint["state_dict"]
        # Strip the "module." prefix left by (Distributed)DataParallel checkpoints.
        if next(iter(sd)).startswith("module"):
            sd = {k[len("module."):]: v for k, v in sd.items()}

        self.model.load_state_dict(sd, strict=False)
        self.model = self.model.to(device).eval()
        
        # Preprocessing
        self.transform = _transform(self.model.visual.input_resolution, is_train=False)
    
    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """
        Process the request and return the fused embedding.
        Args:
            data (dict): Includes 'sketch' (base64) and 'text' (str) inputs, or 'image' (base64)
        Returns:
            dict: {"embedding": [float, float, ...]}
        """
        
        inputs = data.pop("inputs", data)
        # fused sketch + text embedding
        if "sketch" in inputs and "text" in inputs:
            sketch_base64 = inputs.get("sketch", "")
            text_query = inputs.get("text", "")
            if not sketch_base64 or not text_query:
                return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}

            # Generate the fused embedding
            fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
            return {"embedding": fused_embedding}
        # image-only embedding
        elif "image" in inputs:
            image_base64 = inputs.get("image", "")
            if not image_base64:
                return {"error": "A non-empty 'image' (base64) input is required."}
            embedding = get_image_embedding(image_base64, self.model, self.transform)
            return {"embedding": embedding}
        # text-only embedding
        elif "text" in inputs:
            text_query = inputs.get("text", "")
            if not text_query:
                return {"error": "A non-empty 'text' input is required."}
            embedding = get_text_embedding(text_query, self.model)
            return {"embedding": embedding}
        else:
            return {"error": "Invalid request: send 'sketch' and 'text', 'image' alone, or 'text' alone."}
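
# A minimal local smoke test (assumptions: the "code/" and "model/" layout expected
# by EndpointHandler.__init__ is present next to this file; "sketch.png" and the
# query text are illustrative placeholders).
if __name__ == "__main__":
    handler = EndpointHandler(path=os.path.dirname(os.path.abspath(__file__)))
    with open("sketch.png", "rb") as f:
        sketch_b64 = base64.b64encode(f.read()).decode("utf-8")
    result = handler({"inputs": {"sketch": sketch_b64, "text": "a cat sitting on a chair"}})
    # The embedding is a nested list of shape (1, d); report its dimensionality.
    print(len(result["embedding"][0]))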