mtesmer-iqnox committed
Commit 556f1d5 · 1 Parent(s): c0a528b
Files changed (1):
  1. app.py +299 -2
app.py CHANGED
@@ -1,4 +1,301 @@
  import streamlit as st

- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
  import streamlit as st
+ import json
+ from typing import List
+ from fastembed import LateInteractionTextEmbedding, TextEmbedding
+ from fastembed import SparseTextEmbedding, SparseEmbedding
+ from qdrant_client import QdrantClient, models
+ from tokenizers import Tokenizer

+ #############################
+ # 1. Utility / Helper Code
+ #############################
+
+ @st.cache_resource
+ def load_tokenizer():
+     """
+     Load the tokenizer for interpreting sparse embeddings (optional usage).
+     """
+     return Tokenizer.from_pretrained(SparseTextEmbedding.list_supported_models()[0]["sources"]["hf"])
+
+ @st.cache_resource
+ def load_models():
+     """
+     Load/initialize your models once and cache them.
+     """
+     # Dense embedding model
+     dense_embedding_model = TextEmbedding("BAAI/bge-small-en-v1.5")
+
+     # Late interaction model (ColBERTv2)
+     late_embedding_model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
+
+     # Sparse embedding model
+     sparse_model_name = "Qdrant/bm25"
+     sparse_model = SparseTextEmbedding(model_name=sparse_model_name)
+
+     return dense_embedding_model, late_embedding_model, sparse_model
+
+ def build_qdrant_index(data):
+     """
+     Given the parsed data (list of items), build an in-memory Qdrant index
+     with dense, late, and sparse vectors.
+     """
+     # Extract fields
+     items = data["items"]
+     descriptions = [f"{item['name']} - {item['description']}" for item in items]
+     names = [item["name"] for item in items]
+     metadata = [
+         {"name": item["name"]}  # You can store more fields if you like
+         for item in items
+     ]
+
+     # Load models
+     dense_embedding_model, late_embedding_model, sparse_model = load_models()
+
+     # Generate embeddings
+     dense_embeddings = list(dense_embedding_model.embed(descriptions))
+     name_dense_embeddings = list(dense_embedding_model.embed(names))
+     late_embeddings = list(late_embedding_model.embed(descriptions))
+     sparse_embeddings: List[SparseEmbedding] = list(sparse_model.embed(descriptions, batch_size=6))
+
+     # Create an in-memory Qdrant instance
+     qdrant_client = QdrantClient(":memory:")
+
+     # Create collection schema
+     qdrant_client.create_collection(
+         collection_name="items",
+         vectors_config={
+             "dense": models.VectorParams(
+                 size=len(dense_embeddings[0]),
+                 distance=models.Distance.COSINE,
+             ),
+             "late": models.VectorParams(
+                 size=len(late_embeddings[0][0]),
+                 distance=models.Distance.COSINE,
+                 multivector_config=models.MultiVectorConfig(
+                     comparator=models.MultiVectorComparator.MAX_SIM
+                 ),
+             ),
+         },
+         sparse_vectors_config={
+             "sparse": models.SparseVectorParams(
+                 modifier=models.Modifier.IDF,
+             ),
+         }
+     )
+
+     # Upload points
+     points = []
+     for idx, _ in enumerate(metadata):
+         points.append(
+             models.PointStruct(
+                 id=idx,
+                 payload=metadata[idx],
+                 vector={
+                     "late": late_embeddings[idx].tolist(),
+                     "dense": dense_embeddings[idx],
+                     "sparse": sparse_embeddings[idx].as_object(),
+                 },
+             )
+         )
+
+     qdrant_client.upload_points(
+         collection_name="items",
+         points=points,
+     )
+
+     return qdrant_client
+
+ def run_queries(qdrant_client, query_text):
+     """
+     Run all the different query types and return results in a dictionary.
+     """
+     # Load models
+     dense_embedding_model, late_embedding_model, sparse_model = load_models()
+
+     # Generate single-query embeddings
+     dense_query = next(dense_embedding_model.query_embed(query_text))
+     late_query = next(late_embedding_model.query_embed(query_text))
+     sparse_query = next(sparse_model.query_embed(query_text))
+
+     # For the fusion approach, we need a list form for prefetch
+     tsq = list(sparse_model.embed(query_text, batch_size=6))
+
+     # We'll store top-5 results for each approach
+     results = {}
+
+     # 1) ColBERT (late)
+     results["C"] = qdrant_client.query_points(
+         collection_name="items",
+         query=late_query,
+         using="late",
+         limit=5,
+         with_payload=True
+     )
+
+     # 2) Sparse only
+     results["S"] = qdrant_client.query_points(
+         collection_name="items",
+         query=models.SparseVector(**sparse_query.as_object()),
+         using="sparse",
+         limit=5,
+         with_payload=True
+     )
+
+     # 3) Dense only
+     results["D"] = qdrant_client.query_points(
+         collection_name="items",
+         query=dense_query,
+         using="dense",
+         limit=5,
+         with_payload=True
+     )
+
+     # 4) Hybrid fusion (RRF for Sparse+Dense)
+     results["S+D-F"] = qdrant_client.query_points(
+         collection_name="items",
+         prefetch=[
+             models.Prefetch(
+                 query=dense_query,
+                 using="dense",
+                 limit=100,
+             ),
+             models.Prefetch(
+                 query=tsq[0].as_object(),
+                 using="sparse",
+                 limit=50,
+             )
+         ],
+         query=models.FusionQuery(fusion=models.Fusion.RRF),
+         limit=5,
+         with_payload=True
+     )
+
+     # 5) Hybrid fusion + ColBERT
+     sparse_dense_prefetch = models.Prefetch(
+         prefetch=[
+             models.Prefetch(query=dense_query, using="dense", limit=100),
+             models.Prefetch(query=tsq[0].as_object(), using="sparse", limit=50),
+         ],
+         limit=10,
+         query=models.FusionQuery(fusion=models.Fusion.RRF),
+     )
+     results["S+D-F-C"] = qdrant_client.query_points(
+         collection_name="items",
+         prefetch=[sparse_dense_prefetch],
+         query=late_query,
+         using="late",
+         limit=5,
+         with_payload=True
+     )
+
+     # 6) Hybrid no-fusion + ColBERT
+     old_prefetch = models.Prefetch(
+         prefetch=[
+             models.Prefetch(
+                 prefetch=[
+                     models.Prefetch(query=dense_query, using="dense", limit=100)
+                 ],
+                 query=tsq[0].as_object(),
+                 using="sparse",
+                 limit=50,
+             )
+         ]
+     )
+     results["S+D-C"] = qdrant_client.query_points(
+         collection_name="items",
+         prefetch=[old_prefetch],
+         query=late_query,
+         using="late",
+         limit=5,
+         with_payload=True
+     )
+
+     return results
+
+ #############################
+ # 2. Streamlit Main App
+ #############################
+
+ def main():
+     st.title("Semantic Search Sandbox")
+
+     # Initialize session state if not present
+     if "json_loaded" not in st.session_state:
+         st.session_state["json_loaded"] = False
+     if "qdrant_client" not in st.session_state:
+         st.session_state["qdrant_client"] = None
+
+     #######################################
+     # Show JSON input only if not loaded
+     #######################################
+     if not st.session_state["json_loaded"]:
+         st.subheader("Paste items.json Here")
+         default_json = """
+         {
+           "items": [
+             {
+               "name": "Example1",
+               "description": "An example item"
+             },
+             {
+               "name": "Example2",
+               "description": "Another item for demonstration"
+             }
+           ]
+         }
+         """.strip()
+
+         json_text = st.text_area("JSON Input", value=default_json, height=300)
+
+         if st.button("Load JSON"):
+             try:
+                 data = json.loads(json_text)
+                 # Build Qdrant index in memory
+                 st.session_state["qdrant_client"] = build_qdrant_index(data)
+                 st.session_state["json_loaded"] = True
+                 st.success("JSON loaded and Qdrant index built successfully!")
+                 st.rerun()
+             except Exception as e:
+                 st.error(f"Error parsing JSON: {e}")
+
+     else:
+         # The data is loaded, show a button to reset if you want to load new JSON
+         if st.button("Load a different JSON"):
+             st.session_state["json_loaded"] = False
+             st.session_state["qdrant_client"] = None
+             st.rerun()  # Refresh the page so the JSON input is shown again
+         else:
+             # Show the search interface
+             query_text = st.text_input("Search Query", value="ACB 1.0 Ports")
+             if st.button("Search"):
+                 if st.session_state["qdrant_client"] is None:
+                     st.warning("Please load valid JSON first.")
+                     return
+
+                 # Run queries
+                 results_dict = run_queries(st.session_state["qdrant_client"], query_text)
+
+                 # Display results in columns
+                 col_names = list(results_dict.keys())
+                 # You can split into multiple rows if there are more than 3
+                 n_cols = 3
+                 # We'll create enough columns to handle all search types
+                 rows_needed = (len(col_names) + n_cols - 1) // n_cols
+
+                 for row_idx in range(rows_needed):
+                     cols = st.columns(n_cols)
+                     for col_idx in range(n_cols):
+                         method_idx = row_idx * n_cols + col_idx
+                         if method_idx < len(col_names):
+                             method = col_names[method_idx]
+                             qdrant_result = results_dict[method]
+
+                             with cols[col_idx]:
+                                 st.markdown(f"### {method}")
+                                 for point in qdrant_result.points:
+                                     name = point.payload.get("name", "Unnamed")
+                                     score = round(point.score, 4) if point.score is not None else "N/A"
+                                     st.write(f"- **{name}** (score={score})")
+
+ if __name__ == "__main__":
+     main()
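
For a quick check outside the Streamlit UI, a minimal smoke-test sketch along these lines could exercise the two core functions added in this commit. It assumes the new file is saved as app.py with fastembed, qdrant-client, and tokenizers installed; the file name smoke_test.py, the sample items, and the query string are illustrative, not part of the commit.

# smoke_test.py -- hypothetical helper, not part of this commit
from app import build_qdrant_index, run_queries

# Illustrative catalog in the {"items": [...]} shape the app expects
data = {
    "items": [
        {"name": "USB hub", "description": "Four-port USB 3.0 hub with power delivery"},
        {"name": "HDMI cable", "description": "Two-meter high-speed HDMI cable"},
    ]
}

client = build_qdrant_index(data)
results = run_queries(client, "USB ports")

# One line per retrieval strategy: ColBERT, sparse, dense, and the hybrid variants
for method, response in results.items():
    print(method, [(p.payload.get("name"), round(p.score, 4)) for p in response.points])

Because the index lives in a ":memory:" Qdrant instance, every run rebuilds it from scratch, and the first run also downloads the three embedding models.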