Spaces:
Sleeping
Sleeping
File size: 5,363 Bytes
500c1ba 567e7ba 500c1ba b687ff9 500c1ba 567e7ba 500c1ba b687ff9 500c1ba 567e7ba a80ee03 567e7ba a80ee03 567e7ba a80ee03 567e7ba b687ff9 567e7ba b687ff9 2733f93 500c1ba 567e7ba 500c1ba 2733f93 500c1ba 2733f93 500c1ba 2733f93 500c1ba b687ff9 500c1ba a6cce41 500c1ba b687ff9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import logging
import torch
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition
from collections import defaultdict
class QdrantSearcher:
    """Thin wrapper around a Qdrant collection for per-user document search.

    All public methods follow a ``(result, error_message)`` return convention:
    exactly one of the pair is ``None``.
    """

    def __init__(self, qdrant_url, access_token):
        """Connect to the Qdrant instance at *qdrant_url* using *access_token* as API key."""
        self.client = QdrantClient(url=qdrant_url, api_key=access_token)

    @staticmethod
    def _normalize_embedding(query_embedding):
        """Coerce *query_embedding* to the flat list of floats Qdrant expects.

        Accepts a ``torch.Tensor`` or ``numpy.ndarray`` of any shape.

        Raises:
            ValueError: if the input is neither type, or yields non-float elements.
        """
        if isinstance(query_embedding, torch.Tensor):
            # .cpu() makes this safe for CUDA tensors; it is a no-op on CPU tensors.
            flat = query_embedding.detach().cpu().numpy().flatten().tolist()
        elif isinstance(query_embedding, np.ndarray):
            flat = query_embedding.flatten().tolist()
        else:
            raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")
        if not all(isinstance(x, float) for x in flat):
            raise ValueError("All elements in query_embedding must be of type float")
        return flat

    @staticmethod
    def _build_filter(user_id, file_id):
        """Build a Qdrant payload filter scoped to *user_id*, optionally to *file_id*."""
        conditions = [FieldCondition(key="user_id", match={"value": user_id})]
        # `is not None` (not truthiness) so falsy-but-valid ids such as 0 or "" still filter.
        if file_id is not None:
            conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
        return Filter(must=conditions)

    def search_documents(self, collection_name, query_embedding, user_id, limit=3, similarity_threshold=0.6, file_id=None):
        """Return the user's chunks most similar to *query_embedding*.

        Args:
            collection_name: Qdrant collection to query.
            query_embedding: torch.Tensor or numpy.ndarray query vector.
            user_id: restricts results to this user's documents.
            limit: maximum number of hits requested from Qdrant.
            similarity_threshold: hits scoring below this are discarded
                (applied *after* ``limit``, so fewer than ``limit`` may survive).
            file_id: optional — restrict the search to a single file.

        Returns:
            ``(hits_list, None)`` where each hit dict carries id, score and
            selected payload fields, or ``(None, error_message)`` on failure
            or when nothing clears the threshold.
        """
        logging.info("Starting document search")
        query_vector = self._normalize_embedding(query_embedding)
        query_filter = self._build_filter(user_id, file_id)
        logging.info("Performing search using the precomputed embeddings for user_id: %s", user_id)
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                query_filter=query_filter,
            )
        except Exception as e:
            # logging.exception keeps the traceback; the caller gets the message only.
            logging.exception("Error during Qdrant search")
            return None, str(e)
        filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]
        if not filtered_hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."
        payload_keys = ("file_id", "file_name", "organization_id", "chunk_index", "chunk_text", "s3_bucket_key")
        hits_list = [
            {
                "id": hit.id,
                "score": hit.score,
                **{key: hit.payload.get(key) for key in payload_keys},
            }
            for hit in filtered_hits
        ]
        logging.info("Document search completed with %d hits", len(hits_list))
        logging.info("Hits: %s", hits_list)
        return hits_list, None

    def search_documents_grouped(self, collection_name, query_embedding, user_id, limit=60, similarity_threshold=0.6, file_id=None):
        """Search and aggregate hits per file, averaging each file's scores.

        Args:
            collection_name: Qdrant collection to query.
            query_embedding: torch.Tensor or numpy.ndarray query vector.
            user_id: restricts results to this user's documents.
            limit: maximum number of raw hits requested from Qdrant.
            similarity_threshold: currently UNUSED — no per-hit filtering is
                applied here; the parameter is kept for interface compatibility.
            file_id: optional — restrict the search to a single file.

        Returns:
            ``(grouped_results, None)`` — a list of
            ``{"file_name": ..., "average_score": ...}`` dicts — or
            ``(None, error_message)`` on failure / empty result.
        """
        logging.info("Starting grouped document search")
        query_vector = self._normalize_embedding(query_embedding)
        query_filter = self._build_filter(user_id, file_id)
        logging.info("Performing grouped search using the precomputed embeddings for user_id: %s", user_id)
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                query_filter=query_filter,
            )
        except Exception as e:
            logging.exception("Error during Qdrant search")
            return None, str(e)
        if not hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."
        # Group raw scores by file name, then reduce each group to its mean.
        scores_by_file = defaultdict(list)
        for hit in hits:
            scores_by_file[hit.payload.get('file_name')].append(hit.score)
        grouped_results = [
            {"file_name": file_name, "average_score": sum(scores) / len(scores)}
            for file_name, scores in scores_by_file.items()
        ]
        logging.info("Grouped search completed with %d results", len(grouped_results))
        logging.info("Grouped Hits: %s", grouped_results)
        return grouped_results, None
|