File size: 5,363 Bytes
500c1ba
567e7ba
 
500c1ba
 
b687ff9
500c1ba
 
567e7ba
500c1ba
 
b687ff9
500c1ba
567e7ba
a80ee03
567e7ba
a80ee03
567e7ba
a80ee03
 
 
 
 
 
 
567e7ba
b687ff9
 
 
 
 
567e7ba
b687ff9
2733f93
500c1ba
 
 
567e7ba
500c1ba
 
 
 
 
 
2733f93
 
500c1ba
2733f93
500c1ba
 
 
 
2733f93
500c1ba
 
 
 
b687ff9
500c1ba
 
 
 
 
 
 
 
a6cce41
500c1ba
b687ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import logging
import torch
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition
from collections import defaultdict

class QdrantSearcher:
    """User-scoped similarity search against a Qdrant vector store.

    Wraps a QdrantClient and exposes two search flavors: per-chunk hits
    (`search_documents`) and per-file score aggregation
    (`search_documents_grouped`). All searches are filtered by `user_id`
    and optionally by `file_id`.
    """

    def __init__(self, qdrant_url, access_token):
        """Connect to the Qdrant server at *qdrant_url*, authenticating with *access_token*."""
        self.client = QdrantClient(url=qdrant_url, api_key=access_token)

    @staticmethod
    def _normalize_embedding(query_embedding):
        """Coerce *query_embedding* into a flat list of Python floats.

        Accepts a torch.Tensor (detached and moved to CPU first, so CUDA
        tensors no longer crash `.numpy()`), a numpy.ndarray, or an
        already-flat sequence of floats.

        Raises:
            ValueError: if the input type is unsupported, or any element
                is not a float after flattening.
        """
        if isinstance(query_embedding, torch.Tensor):
            # .cpu() is required before .numpy() when the tensor lives on a GPU.
            query_embedding = query_embedding.detach().cpu().numpy().flatten().tolist()
        elif isinstance(query_embedding, np.ndarray):
            query_embedding = query_embedding.flatten().tolist()
        elif isinstance(query_embedding, (list, tuple)):
            # Already in (or close to) the target format; just copy to a list.
            query_embedding = list(query_embedding)
        else:
            raise ValueError(
                "query_embedding must be a torch.Tensor, numpy.ndarray, or sequence of floats"
            )

        # Qdrant expects a flat vector of floats; reject anything else early.
        if not all(isinstance(x, float) for x in query_embedding):
            raise ValueError("All elements in query_embedding must be of type float")
        return query_embedding

    @staticmethod
    def _build_filter(user_id, file_id):
        """Build the Qdrant payload filter: always user_id, plus file_id when given."""
        filter_conditions = [FieldCondition(key="user_id", match={"value": user_id})]
        if file_id:
            filter_conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
        return Filter(must=filter_conditions)

    def search_documents(self, collection_name, query_embedding, user_id, limit=3, similarity_threshold=0.6, file_id=None):
        """Search *collection_name* for chunks similar to *query_embedding*.

        Results are restricted to *user_id* (and *file_id* if given), and
        hits scoring below *similarity_threshold* are discarded.

        Returns:
            tuple: ``(hits_list, None)`` on success, where each hit is a dict
            with id, score, and selected payload fields; ``(None, message)``
            on search failure or when no hit passes the threshold.
        """
        logging.info("Starting document search")

        query_embedding = self._normalize_embedding(query_embedding)
        query_filter = self._build_filter(user_id, file_id)

        logging.info(f"Performing search using the precomputed embeddings for user_id: {user_id}")
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error(f"Error during Qdrant search: {e}")
            return None, str(e)

        # Drop weak matches below the similarity threshold.
        filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]

        if not filtered_hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."

        hits_list = [
            {
                "id": hit.id,
                "score": hit.score,
                "file_id": hit.payload.get('file_id'),
                "file_name": hit.payload.get('file_name'),
                "organization_id": hit.payload.get('organization_id'),
                "chunk_index": hit.payload.get('chunk_index'),
                "chunk_text": hit.payload.get('chunk_text'),
                "s3_bucket_key": hit.payload.get('s3_bucket_key')
            }
            for hit in filtered_hits
        ]

        logging.info(f"Document search completed with {len(hits_list)} hits")
        # Full hit payloads (including chunk text) are verbose; keep at debug level.
        logging.debug(f"Hits: {hits_list}")
        return hits_list, None

    def search_documents_grouped(self, collection_name, query_embedding, user_id, limit=60, similarity_threshold=0.6, file_id=None):
        """Search *collection_name* and aggregate hits per file name.

        Fetches up to *limit* hits for *user_id* (and *file_id* if given),
        then averages hit scores per file.

        NOTE(review): *similarity_threshold* is currently NOT applied here —
        the per-hit filter was explicitly commented out in the original code,
        so it is kept inert to preserve behavior; the parameter remains for
        interface compatibility. Confirm intent before enabling it.

        Returns:
            tuple: ``(grouped_results, None)`` on success, where each entry
            is ``{"file_name": ..., "average_score": ...}``;
            ``(None, message)`` on search failure or when nothing matches.
        """
        logging.info("Starting grouped document search")

        query_embedding = self._normalize_embedding(query_embedding)
        query_filter = self._build_filter(user_id, file_id)

        logging.info(f"Performing grouped search using the precomputed embeddings for user_id: {user_id}")
        try:
            hits = self.client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=limit,
                query_filter=query_filter
            )
        except Exception as e:
            logging.error(f"Error during Qdrant search: {e}")
            return None, str(e)

        if not hits:
            logging.info("No documents found for the given query")
            return None, "No documents found for the given query."

        # Group hit scores by file name (insertion order preserved, matching
        # the rank order Qdrant returned them in).
        grouped_hits = defaultdict(list)
        for hit in hits:
            grouped_hits[hit.payload.get('file_name')].append(hit.score)

        grouped_results = [
            {
                "file_name": file_name,
                "average_score": sum(scores) / len(scores)
            }
            for file_name, scores in grouped_hits.items()
        ]

        logging.info(f"Grouped search completed with {len(grouped_results)} results")
        logging.debug(f"Grouped Hits: {grouped_results}")
        return grouped_results, None