import json
from typing import List

import streamlit as st
from fastembed import LateInteractionTextEmbedding, TextEmbedding
from fastembed import SparseTextEmbedding, SparseEmbedding
from qdrant_client import QdrantClient, models
from tokenizers import Tokenizer

#############################
# 1. Utility / Helper Code
#############################

# Name of the single in-memory collection every query targets.
COLLECTION_NAME = "items"
# Number of hits returned per search strategy.
RESULT_LIMIT = 5


@st.cache_resource
def load_tokenizer():
    """
    Load the tokenizer for interpreting sparse embeddings (optional usage).

    Not called anywhere in this module; kept as an optional debugging aid.
    """
    return Tokenizer.from_pretrained(
        SparseTextEmbedding.list_supported_models()[0]["sources"]["hf"]
    )


@st.cache_resource
def load_models():
    """
    Load the dense, late-interaction and sparse embedding models once.

    ``st.cache_resource`` keeps the instances across Streamlit reruns, so the
    index builder and the query runner share the same models.

    Returns:
        Tuple ``(dense_embedding_model, late_embedding_model, sparse_model)``.
    """
    # Dense embedding model
    dense_embedding_model = TextEmbedding("BAAI/bge-small-en-v1.5")
    # Late interaction model (ColBERTv2)
    late_embedding_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
    # Sparse embedding model
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
    return dense_embedding_model, late_embedding_model, sparse_model


def _to_sparse_vector(embedding: SparseEmbedding) -> models.SparseVector:
    """Convert a fastembed sparse embedding into a Qdrant ``SparseVector``."""
    return models.SparseVector(
        indices=embedding.indices.tolist(),
        values=embedding.values.tolist(),
    )


def build_qdrant_index(data):
    """
    Build an in-memory Qdrant index with dense, late, and sparse vectors.

    Args:
        data: Parsed JSON object shaped like
            ``{"items": [{"name": ..., "description": ...}, ...]}``.

    Returns:
        A ``QdrantClient`` backed by ``":memory:"``. Its ``"items"`` collection
        holds one point per item (id = list position, payload = ``{"name": ...}``).

    Raises:
        ValueError: If ``data`` is not an object with a non-empty ``"items"``
            list. Without at least one item there is no embedding to size
            the vector configuration from.
        KeyError: If an item lacks ``"name"`` or ``"description"``.
    """
    items = data.get("items") if isinstance(data, dict) else None
    if not isinstance(items, list) or not items:
        raise ValueError('Expected a JSON object with a non-empty "items" list.')

    descriptions = [f"{item['name']} - {item['description']}" for item in items]
    # You can store more fields in the payload if you like.
    payloads = [{"name": item["name"]} for item in items]

    dense_embedding_model, late_embedding_model, sparse_model = load_models()

    dense_embeddings = list(dense_embedding_model.embed(descriptions))
    late_embeddings = list(late_embedding_model.embed(descriptions))
    sparse_embeddings: List[SparseEmbedding] = list(
        sparse_model.embed(descriptions, batch_size=6)
    )

    qdrant_client = QdrantClient(":memory:")
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config={
            "dense": models.VectorParams(
                size=len(dense_embeddings[0]),
                distance=models.Distance.COSINE,
            ),
            # One vector per token; scored with MaxSim (ColBERT style).
            "late": models.VectorParams(
                size=len(late_embeddings[0][0]),
                distance=models.Distance.COSINE,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        sparse_vectors_config={
            # Qdrant applies IDF weighting to the sparse vectors server-side.
            "sparse": models.SparseVectorParams(modifier=models.Modifier.IDF),
        },
    )

    points = [
        models.PointStruct(
            id=idx,
            payload=payload,
            vector={
                "late": late.tolist(),
                "dense": dense.tolist(),
                "sparse": _to_sparse_vector(sparse),
            },
        )
        for idx, (payload, dense, late, sparse) in enumerate(
            zip(payloads, dense_embeddings, late_embeddings, sparse_embeddings)
        )
    ]
    qdrant_client.upload_points(collection_name=COLLECTION_NAME, points=points)
    return qdrant_client


def run_queries(qdrant_client, query_text):
    """
    Run every search strategy for ``query_text`` and collect the responses.

    Args:
        qdrant_client: Client returned by ``build_qdrant_index``.
        query_text: The user's search string.

    Returns:
        Dict mapping a strategy label to its ``query_points`` response, in
        display order: ``"C"`` (ColBERT), ``"S"`` (sparse), ``"D"`` (dense),
        ``"S+D-F"`` (RRF fusion of sparse + dense), ``"S+D-F-C"`` (fusion then
        ColBERT rerank) and ``"S+D-C"`` (chained dense -> sparse -> ColBERT).
        Each response holds at most ``RESULT_LIMIT`` points with payloads.
    """
    dense_embedding_model, late_embedding_model, sparse_model = load_models()

    dense_query = next(dense_embedding_model.query_embed(query_text))
    late_query = next(late_embedding_model.query_embed(query_text))
    # Use the query-side BM25 encoder for every sparse lookup, so the
    # sparse-only search and the hybrid prefetches score the query identically.
    # (The document encoder ``embed`` is for indexing, not for queries.)
    sparse_query = _to_sparse_vector(next(sparse_model.query_embed(query_text)))

    def search(**kwargs):
        # Shared settings for every top-level query.
        return qdrant_client.query_points(
            collection_name=COLLECTION_NAME,
            limit=RESULT_LIMIT,
            with_payload=True,
            **kwargs,
        )

    def dense_and_sparse_prefetch():
        # Independent candidate pools: dense top-100 and sparse top-50.
        return [
            models.Prefetch(query=dense_query, using="dense", limit=100),
            models.Prefetch(query=sparse_query, using="sparse", limit=50),
        ]

    rrf = models.FusionQuery(fusion=models.Fusion.RRF)
    results = {}

    # 1) ColBERT (late) only
    results["C"] = search(query=late_query, using="late")
    # 2) Sparse only
    results["S"] = search(query=sparse_query, using="sparse")
    # 3) Dense only
    results["D"] = search(query=dense_query, using="dense")
    # 4) Hybrid fusion (RRF over sparse + dense candidates)
    results["S+D-F"] = search(prefetch=dense_and_sparse_prefetch(), query=rrf)

    # 5) Hybrid fusion (top 10) re-ranked by ColBERT
    fused_candidates = models.Prefetch(
        prefetch=dense_and_sparse_prefetch(),
        query=rrf,
        limit=10,
    )
    results["S+D-F-C"] = search(
        prefetch=[fused_candidates], query=late_query, using="late"
    )

    # 6) No fusion: dense top-100 re-scored by sparse (top 50), then ColBERT
    chained_candidates = models.Prefetch(
        prefetch=[
            models.Prefetch(
                prefetch=[
                    models.Prefetch(query=dense_query, using="dense", limit=100)
                ],
                query=sparse_query,
                using="sparse",
                limit=50,
            )
        ]
    )
    results["S+D-C"] = search(
        prefetch=[chained_candidates], query=late_query, using="late"
    )

    return results


#############################
# 2. Streamlit Main App
#############################

DEFAULT_JSON = """
{
  "items": [
    { "name": "Example1", "description": "An example item" },
    { "name": "Example2", "description": "Another item for demonstration" }
  ]
}
""".strip()


def _render_json_loader():
    """Show the JSON input and build the index when "Load JSON" is clicked."""
    st.subheader("Paste items.json Here")
    json_text = st.text_area("JSON Input", value=DEFAULT_JSON, height=300)
    if not st.button("Load JSON"):
        return

    try:
        data = json.loads(json_text)
    except json.JSONDecodeError as e:
        st.error(f"Error parsing JSON: {e}")
        return

    try:
        client = build_qdrant_index(data)
    except Exception as e:  # UI boundary: report schema/model/index failures
        st.error(f"Error building index: {e}")
        return

    st.session_state["qdrant_client"] = client
    st.session_state["json_loaded"] = True
    st.success("JSON loaded and Qdrant index built successfully!")
    # Kept outside any try block so the rerun request is never swallowed.
    st.rerun()


def _render_results(results_dict, n_cols=3):
    """Lay out one column per search strategy, ``n_cols`` columns per row."""
    methods = list(results_dict)
    for row_start in range(0, len(methods), n_cols):
        cols = st.columns(n_cols)
        for col, method in zip(cols, methods[row_start:row_start + n_cols]):
            with col:
                st.markdown(f"### {method}")
                for point in results_dict[method].points:
                    name = (point.payload or {}).get("name", "Unnamed")
                    # A score of 0.0 is a real score; only a missing one is N/A.
                    score = "N/A" if point.score is None else round(point.score, 4)
                    st.write(f"- **{name}** (score={score})")


def _render_search():
    """Show the reset button, the query box and, after "Search", the results."""
    if st.button("Load a different JSON"):
        st.session_state["json_loaded"] = False
        st.session_state["qdrant_client"] = None
        # Refresh immediately so the JSON input reappears on this click.
        st.rerun()

    query_text = st.text_input("Search Query", value="ACB 1.0 Ports")
    if not st.button("Search"):
        return
    if st.session_state["qdrant_client"] is None:
        st.warning("Please load valid JSON first.")
        return
    if not query_text.strip():
        st.warning("Please enter a search query.")
        return

    results_dict = run_queries(st.session_state["qdrant_client"], query_text)
    _render_results(results_dict)


def main():
    """Streamlit entry point: JSON loading screen, then the search screen."""
    st.title("Semantic Search Sandbox")

    # Initialize session state if not present
    if "json_loaded" not in st.session_state:
        st.session_state["json_loaded"] = False
    if "qdrant_client" not in st.session_state:
        st.session_state["qdrant_client"] = None

    if st.session_state["json_loaded"]:
        _render_search()
    else:
        _render_json_loader()


if __name__ == "__main__":
    main()