Commit
·
90f3fab
1
Parent(s):
cdd2f2d
Refactor app.py: Update imports, add get_image_url function, and optimize search functionality
Browse files
app.py
CHANGED
@@ -6,11 +6,11 @@ from __future__ import annotations
|
|
6 |
|
7 |
import os
|
8 |
from time import time
|
|
|
9 |
|
10 |
-
import faiss
|
11 |
-
import pandas as pd
|
12 |
import streamlit as st
|
13 |
-
|
|
|
14 |
from openai import OpenAI
|
15 |
from qdrant_client import QdrantClient
|
16 |
from qdrant_client.http import models
|
@@ -29,16 +29,27 @@ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
|
29 |
QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT")
|
30 |
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
|
31 |
|
|
|
|
|
|
|
32 |
if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY:
|
33 |
raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.")
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
36 |
@st.cache_resource
|
37 |
-
def
|
|
|
|
|
|
|
38 |
model, _, preprocess = create_model_and_transforms(
|
39 |
-
|
40 |
)
|
41 |
-
|
|
|
42 |
|
43 |
|
44 |
@st.cache_resource
|
@@ -50,9 +61,48 @@ def get_qdrant_client():
|
|
50 |
return qdrant_client
|
51 |
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
def app():
|
|
|
54 |
st.title("secon.dev site search")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
if __name__ == "__main__":
|
|
|
|
|
|
|
58 |
app()
|
|
|
6 |
|
7 |
import os
|
8 |
from time import time
|
9 |
+
from typing import Literal
|
10 |
|
|
|
|
|
11 |
import streamlit as st
|
12 |
+
import torch
|
13 |
+
from open_clip import create_model_and_transforms, get_tokenizer
|
14 |
from openai import OpenAI
|
15 |
from qdrant_client import QdrantClient
|
16 |
from qdrant_client.http import models
|
|
|
29 |
QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT")
|
30 |
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
|
31 |
|
32 |
+
BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/"
|
33 |
+
TargetImageType = Literal["xsmall", "small", "medium", "large"]
|
34 |
+
|
35 |
if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY:
|
36 |
raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.")
|
37 |
|
38 |
|
39 |
+
def get_image_url(image_name: str, image_type: TargetImageType = "xsmall") -> str:
|
40 |
+
return f"{BASE_IMAGE_URL}{image_type}/{image_name}.webp"
|
41 |
+
|
42 |
+
|
43 |
@st.cache_resource
|
44 |
+
def get_model_preprocess_tokenizer(
|
45 |
+
target_model: str = "xlm-roberta-base-ViT-B-32",
|
46 |
+
pretrained: str = "laion5B-s13B-b90k",
|
47 |
+
):
|
48 |
model, _, preprocess = create_model_and_transforms(
|
49 |
+
target_model, pretrained=pretrained
|
50 |
)
|
51 |
+
tokenizer = get_tokenizer(target_model)
|
52 |
+
return model, preprocess, tokenizer
|
53 |
|
54 |
|
55 |
@st.cache_resource
|
|
|
61 |
return qdrant_client
|
62 |
|
63 |
|
64 |
+
@st.cache_data
|
65 |
+
def get_text_features(text: str):
|
66 |
+
model, preprocess, tokenizer = get_model_preprocess_tokenizer()
|
67 |
+
text_tokenized = tokenizer([text])
|
68 |
+
with torch.no_grad():
|
69 |
+
text_features = model.encode_text(text_tokenized) # type: ignore
|
70 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
71 |
+
# tensor to list
|
72 |
+
return text_features[0].tolist()
|
73 |
+
|
74 |
+
|
75 |
def app():
|
76 |
+
_, _, _ = get_model_preprocess_tokenizer() # for cache
|
77 |
st.title("secon.dev site search")
|
78 |
+
search_text = st.text_input("Search", key="search_text")
|
79 |
+
if search_text:
|
80 |
+
st.write("searching...")
|
81 |
+
start = time()
|
82 |
+
qdrant_client = get_qdrant_client()
|
83 |
+
text_features = get_text_features(search_text)
|
84 |
+
search_results = qdrant_client.search(
|
85 |
+
collection_name="images-clip",
|
86 |
+
query_vector=text_features,
|
87 |
+
limit=20,
|
88 |
+
)
|
89 |
+
elapsed = time() - start
|
90 |
+
st.write(f"elapsed: {elapsed:.2f} sec")
|
91 |
+
st.write(f"total: {len(search_results)}")
|
92 |
+
for r in search_results:
|
93 |
+
score = r.score
|
94 |
+
if payload := r.payload:
|
95 |
+
name = payload["name"]
|
96 |
+
else:
|
97 |
+
name = "unknown"
|
98 |
+
image_url = get_image_url(name, image_type="xsmall")
|
99 |
+
st.write(f"score: {score:.2f}")
|
100 |
+
st.image(image_url, width=200)
|
101 |
+
st.write("---")
|
102 |
|
103 |
|
104 |
if __name__ == "__main__":
|
105 |
+
st.set_page_config(
|
106 |
+
layout="wide", page_icon="https://secon.dev/images/profile_usa.png"
|
107 |
+
)
|
108 |
app()
|