tcm03
commited on
Commit
·
aef4077
1
Parent(s):
1060235
Enable text-only embedding request
Browse files- handler.py +21 -4
handler.py
CHANGED
@@ -47,6 +47,14 @@ def get_image_embedding(image_base64, model, transformer):
|
|
47 |
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
48 |
return image_feature.cpu().numpy().tolist()
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
class EndpointHandler:
|
51 |
def __init__(self, path: str = ""):
|
52 |
"""
|
@@ -80,9 +88,10 @@ class EndpointHandler:
|
|
80 |
Returns:
|
81 |
dict: {"embedding": [float, float, ...]}
|
82 |
"""
|
83 |
-
|
84 |
inputs = data.pop("inputs", data)
|
85 |
-
|
|
|
86 |
sketch_base64 = inputs.get("sketch", "")
|
87 |
text_query = inputs.get("text", "")
|
88 |
if not sketch_base64 or not text_query:
|
@@ -91,11 +100,19 @@ class EndpointHandler:
|
|
91 |
# Generate Fused Embedding
|
92 |
fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
|
93 |
return {"embedding": fused_embedding}
|
94 |
-
|
|
|
95 |
image_base64 = inputs.get("image", "")
|
96 |
if not image_base64:
|
97 |
return {"error": "Image 'image' (base64) is required input."}
|
98 |
embedding = get_image_embedding(image_base64, self.model, self.transform)
|
99 |
return {"embedding": embedding}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
else:
|
101 |
-
return {"error": "
|
|
|
47 |
image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
|
48 |
return image_feature.cpu().numpy().tolist()
|
49 |
|
50 |
+
def get_text_embedding(text, model):
|
51 |
+
"""Convert text query to tensor."""
|
52 |
+
text_tensor = preprocess_text(text)
|
53 |
+
with torch.no_grad():
|
54 |
+
text_feature = model.encode_text(text_tensor)
|
55 |
+
text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
|
56 |
+
return text_feature.cpu().numpy().tolist()
|
57 |
+
|
58 |
class EndpointHandler:
|
59 |
def __init__(self, path: str = ""):
|
60 |
"""
|
|
|
88 |
Returns:
|
89 |
dict: {"embedding": [float, float, ...]}
|
90 |
"""
|
91 |
+
|
92 |
inputs = data.pop("inputs", data)
|
93 |
+
# text-sketch embedding
|
94 |
+
if len(inputs) == 2 and "sketch" in inputs and "text" in inputs:
|
95 |
sketch_base64 = inputs.get("sketch", "")
|
96 |
text_query = inputs.get("text", "")
|
97 |
if not sketch_base64 or not text_query:
|
|
|
100 |
# Generate Fused Embedding
|
101 |
fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
|
102 |
return {"embedding": fused_embedding}
|
103 |
+
# image-only embedding
|
104 |
+
elif len(inputs) == 1 and "image" in inputs:
|
105 |
image_base64 = inputs.get("image", "")
|
106 |
if not image_base64:
|
107 |
return {"error": "Image 'image' (base64) is required input."}
|
108 |
embedding = get_image_embedding(image_base64, self.model, self.transform)
|
109 |
return {"embedding": embedding}
|
110 |
+
# text-only embedding
|
111 |
+
elif len(inputs) == 1 and "text" in inputs:
|
112 |
+
text_query = inputs.get("text", "")
|
113 |
+
if not text_query:
|
114 |
+
return {"error": "Text 'text' is required input."}
|
115 |
+
embedding = get_text_embedding(text_query, self.model)
|
116 |
+
return {"embedding": embedding}
|
117 |
else:
|
118 |
+
return {"error": "Invalid request."}
|