# NOTE: "Spaces: Running" page-chrome captured when this file was scraped from a
# Hugging Face Space; converted to a comment so the module parses.
import functools
import json
from typing import List

import streamlit as st
from fastembed import LateInteractionTextEmbedding, TextEmbedding
from fastembed import SparseTextEmbedding, SparseEmbedding
from qdrant_client import QdrantClient, models
from tokenizers import Tokenizer
#############################
# 1. Utility / Helper Code
#############################
def load_tokenizer():
    """Load the tokenizer for interpreting sparse embeddings (optional usage)."""
    # Resolve the Hugging Face repo of the first supported sparse model,
    # then fetch that repo's tokenizer.
    first_supported = SparseTextEmbedding.list_supported_models()[0]
    hf_repo = first_supported["sources"]["hf"]
    return Tokenizer.from_pretrained(hf_repo)
@functools.lru_cache(maxsize=None)
def load_models():
    """
    Load/initialize the embedding models once and cache them.

    BUG FIX: the docstring promised caching but none existed, so every call
    (one per index build and one per search) re-instantiated all three models.
    ``lru_cache`` on this zero-argument function makes it a true singleton.

    Returns:
        tuple: (dense_embedding_model, late_embedding_model, sparse_model)
    """
    # Dense embedding model
    dense_embedding_model = TextEmbedding("BAAI/bge-small-en-v1.5")
    # Late interaction model (ColBERTv2)
    late_embedding_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
    # Sparse embedding model (BM25)
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
    return dense_embedding_model, late_embedding_model, sparse_model
def build_qdrant_index(data):
    """
    Build an in-memory Qdrant index from parsed data.

    Args:
        data: dict with an "items" key holding a list of
            {"name": ..., "description": ...} dicts.

    Returns:
        QdrantClient: in-memory client with a populated "items" collection
        carrying dense, late-interaction (multivector), and sparse vectors.
    """
    # Extract fields
    items = data["items"]
    descriptions = [f"{item['name']} - {item['description']}" for item in items]
    # Payload stored alongside each point; add more fields here if needed.
    metadata = [{"name": item["name"]} for item in items]

    # Load models (cached)
    dense_embedding_model, late_embedding_model, sparse_model = load_models()

    # Generate embeddings.
    # BUG FIX: the original also embedded the bare item names
    # (name_dense_embeddings) — a full model pass whose output was never
    # used anywhere; it has been removed.
    dense_embeddings = list(dense_embedding_model.embed(descriptions))
    late_embeddings = list(late_embedding_model.embed(descriptions))
    sparse_embeddings: List[SparseEmbedding] = list(
        sparse_model.embed(descriptions, batch_size=6)
    )

    # Create an in-memory Qdrant instance
    qdrant_client = QdrantClient(":memory:")

    # Create collection schema
    qdrant_client.create_collection(
        collection_name="items",
        vectors_config={
            "dense": models.VectorParams(
                size=len(dense_embeddings[0]),
                distance=models.Distance.COSINE,
            ),
            "late": models.VectorParams(
                size=len(late_embeddings[0][0]),
                distance=models.Distance.COSINE,
                # ColBERT-style multivectors are compared with MaxSim.
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        sparse_vectors_config={
            "sparse": models.SparseVectorParams(
                # IDF modifier: Qdrant applies inverse-document-frequency
                # weighting server-side (required for BM25-style scoring).
                modifier=models.Modifier.IDF,
            ),
        },
    )

    # Upload points: one point per item, carrying all three vector types.
    points = [
        models.PointStruct(
            id=idx,
            payload=payload,
            vector={
                "late": late.tolist(),
                "dense": dense,
                "sparse": sparse.as_object(),
            },
        )
        for idx, (payload, late, dense, sparse) in enumerate(
            zip(metadata, late_embeddings, dense_embeddings, sparse_embeddings)
        )
    ]
    qdrant_client.upload_points(
        collection_name="items",
        points=points,
    )
    return qdrant_client
def run_queries(qdrant_client, query_text):
    """
    Run all the different query types and return results in a dictionary.

    Args:
        qdrant_client: client holding the "items" collection built by
            build_qdrant_index.
        query_text: free-text search query.

    Returns:
        dict: method label -> Qdrant query response with top-5 scored points.
    """
    # Load models (cached)
    dense_embedding_model, late_embedding_model, sparse_model = load_models()

    # Generate single-query embeddings.
    # BUG FIX: the fusion branches previously embedded the query with
    # sparse_model.embed() (document-side weighting) while the sparse-only
    # branch used query_embed() (query-side weighting) — for BM25 these
    # differ. All sparse query usages now share the query_embed() result.
    dense_query = next(dense_embedding_model.query_embed(query_text))
    late_query = next(late_embedding_model.query_embed(query_text))
    sparse_query = models.SparseVector(
        **next(sparse_model.query_embed(query_text)).as_object()
    )

    # We'll store top-5 results for each approach
    results = {}

    # 1) ColBERT (late interaction) only
    results["C"] = qdrant_client.query_points(
        collection_name="items",
        query=late_query,
        using="late",
        limit=5,
        with_payload=True,
    )

    # 2) Sparse only
    results["S"] = qdrant_client.query_points(
        collection_name="items",
        query=sparse_query,
        using="sparse",
        limit=5,
        with_payload=True,
    )

    # 3) Dense only
    results["D"] = qdrant_client.query_points(
        collection_name="items",
        query=dense_query,
        using="dense",
        limit=5,
        with_payload=True,
    )

    # 4) Hybrid fusion (RRF over sparse + dense candidate sets)
    results["S+D-F"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[
            models.Prefetch(query=dense_query, using="dense", limit=100),
            models.Prefetch(query=sparse_query, using="sparse", limit=50),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=5,
        with_payload=True,
    )

    # 5) Hybrid fusion, then ColBERT rerank of the fused top-10
    sparse_dense_prefetch = models.Prefetch(
        prefetch=[
            models.Prefetch(query=dense_query, using="dense", limit=100),
            models.Prefetch(query=sparse_query, using="sparse", limit=50),
        ],
        limit=10,
        query=models.FusionQuery(fusion=models.Fusion.RRF),
    )
    results["S+D-F-C"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[sparse_dense_prefetch],
        query=late_query,
        using="late",
        limit=5,
        with_payload=True,
    )

    # 6) Chained (no fusion): dense -> sparse rescoring -> ColBERT rerank
    chained_prefetch = models.Prefetch(
        prefetch=[
            models.Prefetch(
                prefetch=[
                    models.Prefetch(query=dense_query, using="dense", limit=100)
                ],
                query=sparse_query,
                using="sparse",
                limit=50,
            )
        ]
    )
    results["S+D-C"] = qdrant_client.query_points(
        collection_name="items",
        prefetch=[chained_prefetch],
        query=late_query,
        using="late",
        limit=5,
        with_payload=True,
    )
    return results
#############################
# 2. Streamlit Main App
#############################
def main():
    """Streamlit entry point: paste JSON, build the index, run searches."""
    st.title("Semantic Search Sandbox")

    # Initialize session state if not present
    if "json_loaded" not in st.session_state:
        st.session_state["json_loaded"] = False
    if "qdrant_client" not in st.session_state:
        st.session_state["qdrant_client"] = None

    #######################################
    # Show JSON input only if not loaded
    #######################################
    if not st.session_state["json_loaded"]:
        st.subheader("Paste items.json Here")
        default_json = """
        {
          "items": [
            {
              "name": "Example1",
              "description": "An example item"
            },
            {
              "name": "Example2",
              "description": "Another item for demonstration"
            }
          ]
        }
        """.strip()
        json_text = st.text_area("JSON Input", value=default_json, height=300)
        if st.button("Load JSON"):
            try:
                data = json.loads(json_text)
                # Build Qdrant index in memory
                st.session_state["qdrant_client"] = build_qdrant_index(data)
                st.session_state["json_loaded"] = True
                st.success("JSON loaded and Qdrant index built successfully!")
                st.rerun()
            except Exception as e:
                st.error(f"Error parsing JSON: {e}")
    else:
        # The data is loaded; offer a reset, otherwise show the search UI.
        if st.button("Load a different JSON"):
            st.session_state["json_loaded"] = False
            st.session_state["qdrant_client"] = None
            # BUG FIX: the rerun was commented out, so the JSON form only
            # appeared on the *next* interaction. Rerun immediately,
            # consistent with the st.rerun() used after a successful load.
            st.rerun()
        else:
            # Show the search interface
            query_text = st.text_input("Search Query", value="ACB 1.0 Ports")
            if st.button("Search"):
                if st.session_state["qdrant_client"] is None:
                    st.warning("Please load valid JSON first.")
                    return
                # Run queries
                results_dict = run_queries(
                    st.session_state["qdrant_client"], query_text
                )
                # Display results in a grid of up to n_cols columns per row.
                col_names = list(results_dict.keys())
                n_cols = 3
                rows_needed = (len(col_names) + n_cols - 1) // n_cols
                for row_idx in range(rows_needed):
                    cols = st.columns(n_cols)
                    for col_idx in range(n_cols):
                        method_idx = row_idx * n_cols + col_idx
                        if method_idx >= len(col_names):
                            continue
                        method = col_names[method_idx]
                        qdrant_result = results_dict[method]
                        with cols[col_idx]:
                            st.markdown(f"### {method}")
                            for point in qdrant_result.points:
                                name = point.payload.get("name", "Unnamed")
                                # BUG FIX: a legitimate score of 0.0 is falsy
                                # and was shown as "N/A"; only a missing score
                                # should be.
                                score = (
                                    round(point.score, 4)
                                    if point.score is not None
                                    else "N/A"
                                )
                                st.write(f"- **{name}** (score={score})")
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()