|
""" |
|
streamlit run app.py --server.address 0.0.0.0 |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import os |
|
from time import time |
|
from typing import Literal |
|
|
|
import streamlit as st |
|
import torch |
|
from open_clip import create_model_and_transforms, get_tokenizer |
|
from openai import OpenAI |
|
from qdrant_client import QdrantClient |
|
from qdrant_client.http import models |
|
|
|
if os.getenv("SPACE_ID"): |
|
USE_HF_SPACE = True |
|
os.environ["HF_HOME"] = "/data/.huggingface" |
|
os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface" |
|
else: |
|
USE_HF_SPACE = False |
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") |
|
QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT") |
|
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") |
|
|
|
BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/" |
|
TargetImageType = Literal["xsmall", "small", "medium", "large"] |
|
|
|
if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY: |
|
raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.") |
|
|
|
|
|
def get_image_url(
    image_name: str,
    image_type: TargetImageType = "xsmall",
    base_url: str | None = None,
) -> str:
    """Build the public URL of a resized image.

    Args:
        image_name: Image file stem (no extension).
        image_type: Size-variant directory under the base URL.
        base_url: Optional override of the image host; defaults to
            ``BASE_IMAGE_URL``. Added for testability/reuse — existing
            callers are unaffected.

    Returns:
        Full URL to the ``.webp`` rendition of the image.
    """
    # Resolve the default lazily so the module constant is only needed
    # when no override is supplied.
    if base_url is None:
        base_url = BASE_IMAGE_URL
    return f"{base_url}{image_type}/{image_name}.webp"
|
|
|
|
|
@st.cache_resource
def get_model_preprocess_tokenizer(
    target_model: str = "xlm-roberta-base-ViT-B-32",
    pretrained: str = "laion5B-s13B-b90k",
):
    """Load the multilingual CLIP model, eval preprocess transform, and tokenizer.

    Cached as a Streamlit resource so the weights are loaded at most once
    per server process.
    """
    clip_model, _unused, eval_preprocess = create_model_and_transforms(
        target_model,
        pretrained=pretrained,
    )
    return clip_model, eval_preprocess, get_tokenizer(target_model)
|
|
|
|
|
@st.cache_resource
def get_qdrant_client():
    """Return a process-wide Qdrant client for the env-configured endpoint."""
    return QdrantClient(url=QDRANT_API_ENDPOINT, api_key=QDRANT_API_KEY)
|
|
|
|
|
@st.cache_data
def get_text_features(text: str):
    """Embed *text* with the CLIP text tower and return it as a plain list.

    The vector is L2-normalized, matching how the image vectors in the
    Qdrant collection are compared. Memoized per input string via
    Streamlit's data cache.
    """
    model, _preprocess, tokenizer = get_model_preprocess_tokenizer()
    tokens = tokenizer([text])
    with torch.no_grad():
        embedding = model.encode_text(tokens)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding[0].tolist()
|
|
|
|
|
def app():
    """Render the search page: text box -> CLIP embedding -> Qdrant -> image grid."""
    # Warm the model cache up front so the first query doesn't pay load time.
    get_model_preprocess_tokenizer()
    st.title("secon.dev site search")
    search_text = st.text_input("Search", key="search_text")
    if not search_text:
        return

    st.write("searching...")
    start = time()
    client = get_qdrant_client()
    query_vector = get_text_features(search_text)
    hits = client.search(
        collection_name="images-clip",
        query_vector=query_vector,
        limit=50,
    )
    elapsed = time() - start
    st.write(f"elapsed: {elapsed:.2f} sec")
    st.write(f"total: {len(hits)}")

    # Build parallel URL/caption lists; hits lacking a payload fall back
    # to a placeholder name.
    images = []
    captions = []
    for hit in hits:
        name = hit.payload["name"] if hit.payload else "unknown"
        images.append(get_image_url(name, image_type="xsmall"))
        captions.append(f"{name} ({hit.score:.4f})")

    # Render thumbnails in rows of six.
    row_size = 6
    for offset in range(0, len(images), row_size):
        st.image(
            images[offset : offset + row_size],
            caption=captions[offset : offset + row_size],
            width=160,
        )
|
|
|
|
|
if __name__ == "__main__": |
|
st.set_page_config( |
|
layout="wide", page_icon="https://secon.dev/images/profile_usa.png" |
|
) |
|
app() |
|
|