vhr1007 committed on
Commit
b687ff9
1 Parent(s): 2733f93

new_version_changes3.0

Browse files
Files changed (3)
  1. app.py +42 -18
  2. services/openai_service.py +2 -2
  3. services/qdrant_searcher.py +65 -2
app.py CHANGED
@@ -10,6 +10,7 @@ from dotenv import load_dotenv
10
  import os
11
  import torch
12
  from utils.auth_x import x_api_key_auth
 
13
 
14
  # Load environment variables from .env file
15
  load_dotenv()
@@ -25,6 +26,8 @@ hf_home_dir = os.environ["HF_HOME"]
25
  if not os.path.exists(hf_home_dir):
26
  os.makedirs(hf_home_dir)
27
 
 
 
28
  # Setup logging using Python's standard logging library
29
  logging.basicConfig(level=logging.INFO)
30
 
@@ -76,14 +79,17 @@ def embed_text(text):
76
  class SearchDocumentsRequest(BaseModel):
77
  query: str
78
  limit: int = 3
 
79
 
80
  class GenerateRAGRequest(BaseModel):
81
  search_query: str
 
82
 
83
  class XApiKeyRequest(BaseModel):
84
  organization_id: str
85
  user_id: str
86
  search_query: str
 
87
 
88
 
89
  @app.get("/")
@@ -97,7 +103,7 @@ async def search_documents(
97
  credentials: tuple = Depends(token_required)
98
  ):
99
  customer_id, user_id = credentials
100
-
101
  if not customer_id or not user_id:
102
  logging.error("Failed to extract customer_id or user_id from the JWT token.")
103
  raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
@@ -109,17 +115,19 @@ async def search_documents(
109
  # Encode the query using the custom embedding function
110
  query_embedding = embed_text(body.query)
111
  print(body.query)
112
- collection_name = "embed" # Use the collection name where the embeddings are stored
113
-
114
  logging.info("Performing search using the precomputed embeddings")
 
 
115
  # Perform search using the precomputed embeddings
116
  hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)
117
 
118
  if error:
119
  logging.error(f"Search documents error: {error}")
120
  raise HTTPException(status_code=500, detail=error)
121
-
122
- return hits
 
123
  except Exception as e:
124
  logging.error(f"Unexpected error: {e}")
125
  raise HTTPException(status_code=500, detail=str(e))
@@ -131,31 +139,41 @@ async def generate_rag_response_api(
131
  credentials: tuple = Depends(token_required)
132
  ):
133
  customer_id, user_id = credentials
134
-
135
  if not customer_id or not user_id:
136
  logging.error("Failed to extract customer_id or user_id from the JWT token.")
137
  raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
138
 
139
  logging.info("Received request to generate RAG response")
 
140
  try:
 
141
  logging.info("Starting document search")
142
-
143
  # Encode the query using the custom embedding function
144
  query_embedding = embed_text(body.search_query)
145
  print(body.search_query)
146
- collection_name = "embed" # Use the collection name where the embeddings are stored
147
  # Perform search using the precomputed embeddings
148
- hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
 
 
 
149
 
150
  if error:
151
  logging.error(f"Search documents error: {error}")
152
  raise HTTPException(status_code=500, detail=error)
153
 
154
  logging.info("Generating RAG response")
155
-
 
 
156
  # Generate the RAG response using the retrieved documents
157
  response, error = generate_rag_response(hits, body.search_query)
158
-
 
 
 
 
159
  if error:
160
  logging.error(f"Generate RAG response error: {error}")
161
  raise HTTPException(status_code=500, detail=error)
@@ -172,9 +190,11 @@ async def search_documents_x_api_key(
172
  ):
173
  if not authorized:
174
  raise HTTPException(status_code=401, detail="Unauthorized")
175
-
176
  organization_id = body.organization_id
177
  user_id = body.user_id
 
 
178
  logging.info(f'search query {body.search_query}')
179
  logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
180
  logging.info("Received request to search documents with x-api-key auth")
@@ -183,16 +203,18 @@ async def search_documents_x_api_key(
183
 
184
  # Encode the query using the custom embedding function
185
  query_embedding = embed_text(body.search_query)
186
- collection_name = "embed" # Use the collection name where the embeddings are stored
187
 
188
  # Perform search using the precomputed embeddings
189
- hits, error = searcher.search_documents(collection_name, query_embedding, user_id, limit=3)
190
 
191
  if error:
192
  logging.error(f"Search documents error: {error}")
193
  raise HTTPException(status_code=500, detail=error)
194
 
195
  logging.info(f"Document search completed with {len(hits)} hits")
 
 
196
  return hits
197
  except Exception as e:
198
  logging.error(f"Unexpected error: {e}")
@@ -206,9 +228,10 @@ async def generate_rag_response_x_api_key(
206
  # Assuming x_api_key_auth validates the key
207
  if not authorized:
208
  raise HTTPException(status_code=401, detail="Unauthorized")
209
-
210
  organization_id = body.organization_id
211
  user_id = body.user_id
 
212
 
213
  logging.info(f'search query {body.search_query}')
214
  logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
@@ -218,10 +241,10 @@ async def generate_rag_response_x_api_key(
218
 
219
  # Encode the query using the custom embedding function
220
  query_embedding = embed_text(body.search_query)
221
- collection_name = "embed" # Use the collection name where the embeddings are stored
222
 
223
  # Perform search using the precomputed embeddings
224
- hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
225
 
226
  if error:
227
  logging.error(f"Search documents error: {error}")
@@ -235,7 +258,8 @@ async def generate_rag_response_x_api_key(
235
  if error:
236
  logging.error(f"Generate RAG response error: {error}")
237
  raise HTTPException(status_code=500, detail=error)
238
-
 
239
  return {"response": response}
240
  except Exception as e:
241
  logging.error(f"Unexpected error: {e}")
 
10
  import os
11
  import torch
12
  from utils.auth_x import x_api_key_auth
13
+ import time
14
 
15
  # Load environment variables from .env file
16
  load_dotenv()
 
26
  if not os.path.exists(hf_home_dir):
27
  os.makedirs(hf_home_dir)
28
 
29
+ collection_name = os.getenv('QDRANT_COLLECTION_NAME')
30
+ logging.info(f"Collection name: {collection_name}")
31
  # Setup logging using Python's standard logging library
32
  logging.basicConfig(level=logging.INFO)
33
 
 
79
  class SearchDocumentsRequest(BaseModel):
80
  query: str
81
  limit: int = 3
82
+ file_id: str = None
83
 
84
  class GenerateRAGRequest(BaseModel):
85
  search_query: str
86
+ file_id: str = None
87
 
88
  class XApiKeyRequest(BaseModel):
89
  organization_id: str
90
  user_id: str
91
  search_query: str
92
+ file_id: str = None
93
 
94
 
95
  @app.get("/")
 
103
  credentials: tuple = Depends(token_required)
104
  ):
105
  customer_id, user_id = credentials
106
+ start_time = time.time()
107
  if not customer_id or not user_id:
108
  logging.error("Failed to extract customer_id or user_id from the JWT token.")
109
  raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
 
115
  # Encode the query using the custom embedding function
116
  query_embedding = embed_text(body.query)
117
  print(body.query)
118
+ #collection_name = "embed" # Use the collection name where the embeddings are stored
 
119
  logging.info("Performing search using the precomputed embeddings")
120
+ if body.file_id:
121
+ hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit, file_id=body.file_id)
122
  # Perform search using the precomputed embeddings
123
  hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)
124
 
125
  if error:
126
  logging.error(f"Search documents error: {error}")
127
  raise HTTPException(status_code=500, detail=error)
128
+ end_time = time.time()
129
+ time_taken = end_time - start_time
130
+ return hits, time_taken
131
  except Exception as e:
132
  logging.error(f"Unexpected error: {e}")
133
  raise HTTPException(status_code=500, detail=str(e))
 
139
  credentials: tuple = Depends(token_required)
140
  ):
141
  customer_id, user_id = credentials
142
+ start_time = time.time()
143
  if not customer_id or not user_id:
144
  logging.error("Failed to extract customer_id or user_id from the JWT token.")
145
  raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
146
 
147
  logging.info("Received request to generate RAG response")
148
+
149
  try:
150
+ search_time = time.time()
151
  logging.info("Starting document search")
 
152
  # Encode the query using the custom embedding function
153
  query_embedding = embed_text(body.search_query)
154
  print(body.search_query)
155
+ #collection_name = "embed" # Use the collection name where the embeddings are stored
156
  # Perform search using the precomputed embeddings
157
+ if body.file_id:
158
+ hits, error = searcher.search_documents(collection_name, query_embedding, user_id, file_id=body.file_id)
159
+ else:
160
+ hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
161
 
162
  if error:
163
  logging.error(f"Search documents error: {error}")
164
  raise HTTPException(status_code=500, detail=error)
165
 
166
  logging.info("Generating RAG response")
167
+ end_search_time = time.time()
168
+ search_time_taken = end_search_time - search_time
169
+ rag_start_time = time.time()
170
  # Generate the RAG response using the retrieved documents
171
  response, error = generate_rag_response(hits, body.search_query)
172
+ rag_end_time = time.time()
173
+ rag_time_taken = rag_end_time - rag_start_time
174
+ end_time= time.time()
175
+ total_time = end_time - start_time
176
+ logging.info(f"Search time: {search_time_taken}, RAG time: {rag_time_taken}, Total time: {total_time}")
177
  if error:
178
  logging.error(f"Generate RAG response error: {error}")
179
  raise HTTPException(status_code=500, detail=error)
 
190
  ):
191
  if not authorized:
192
  raise HTTPException(status_code=401, detail="Unauthorized")
193
+ start_time = time.time()
194
  organization_id = body.organization_id
195
  user_id = body.user_id
196
+ file_id = body.file_id
197
+
198
  logging.info(f'search query {body.search_query}')
199
  logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
200
  logging.info("Received request to search documents with x-api-key auth")
 
203
 
204
  # Encode the query using the custom embedding function
205
  query_embedding = embed_text(body.search_query)
206
+ #collection_name = "embed" # Use the collection name where the embeddings are stored
207
 
208
  # Perform search using the precomputed embeddings
209
+ hits, error = searcher.search_documents(collection_name, query_embedding, user_id, limit=3, file_id=file_id)
210
 
211
  if error:
212
  logging.error(f"Search documents error: {error}")
213
  raise HTTPException(status_code=500, detail=error)
214
 
215
  logging.info(f"Document search completed with {len(hits)} hits")
216
+ end_time = time.time()
217
+ logging.info(f"Time taken: {end_time - start_time}")
218
  return hits
219
  except Exception as e:
220
  logging.error(f"Unexpected error: {e}")
 
228
  # Assuming x_api_key_auth validates the key
229
  if not authorized:
230
  raise HTTPException(status_code=401, detail="Unauthorized")
231
+ start_time = time.time()
232
  organization_id = body.organization_id
233
  user_id = body.user_id
234
+ file_id = body.file_id
235
 
236
  logging.info(f'search query {body.search_query}')
237
  logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
 
241
 
242
  # Encode the query using the custom embedding function
243
  query_embedding = embed_text(body.search_query)
244
+ #collection_name = "embed" # Use the collection name where the embeddings are stored
245
 
246
  # Perform search using the precomputed embeddings
247
+ hits, error = searcher.search_documents(collection_name, query_embedding, user_id, file_id=file_id)
248
 
249
  if error:
250
  logging.error(f"Search documents error: {error}")
 
258
  if error:
259
  logging.error(f"Generate RAG response error: {error}")
260
  raise HTTPException(status_code=500, detail=error)
261
+ end_time = time.time()
262
+ logging.info(f"Time taken: {end_time - start_time}")
263
  return {"response": response}
264
  except Exception as e:
265
  logging.error(f"Unexpected error: {e}")
services/openai_service.py CHANGED
@@ -28,7 +28,7 @@ def generate_rag_response(json_output, user_query):
28
 
29
  # Create the context for the prompt
30
  context = "\n".join(context_texts)
31
- prompt = f"Based on the given context, answer the user query: {user_query}\nContext:\n{context}"
32
 
33
  main_prompt = [
34
  {"role": "system", "content": "You are a helpful assistant."},
@@ -39,7 +39,7 @@ def generate_rag_response(json_output, user_query):
39
  # Create a chat completion request
40
  chat_completion = client.chat.completions.create(
41
  messages=main_prompt,
42
- model="gpt-35-turbo", # Use the gpt-4o-mini model
43
  max_tokens=2000, # Limit the maximum number of tokens in the response
44
  temperature=0.5
45
  )
 
28
 
29
  # Create the context for the prompt
30
  context = "\n".join(context_texts)
31
+ prompt = f"Based on the given context, answer the user query: {user_query}\nContext:\n{context} and Employ references to the ID of articles provided [ID], ensuring their relevance to the query. The referencing should always be in the format of [1][2]... etc. </instructions> "
32
 
33
  main_prompt = [
34
  {"role": "system", "content": "You are a helpful assistant."},
 
39
  # Create a chat completion request
40
  chat_completion = client.chat.completions.create(
41
  messages=main_prompt,
42
+ model="urdu-llama", # Use the gpt-4o-mini model
43
  max_tokens=2000, # Limit the maximum number of tokens in the response
44
  temperature=0.5
45
  )
services/qdrant_searcher.py CHANGED
@@ -3,12 +3,13 @@ import torch
3
  import numpy as np
4
  from qdrant_client import QdrantClient
5
  from qdrant_client.http.models import Filter, FieldCondition
 
6
 
7
  class QdrantSearcher:
8
  def __init__(self, qdrant_url, access_token):
9
  self.client = QdrantClient(url=qdrant_url, api_key=access_token)
10
 
11
- def search_documents(self, collection_name, query_embedding, user_id, limit=3,similarity_threshold=0.6):
12
  logging.info("Starting document search")
13
 
14
  # Ensure the query_embedding is in the correct format (flat list of floats)
@@ -23,8 +24,13 @@ class QdrantSearcher:
23
  if not all(isinstance(x, float) for x in query_embedding):
24
  raise ValueError("All elements in query_embedding must be of type float")
25
 
 
 
 
 
 
26
  # Filter by user_id
27
- query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
28
  logging.info(f"Performing search using the precomputed embeddings for user_id: {user_id}")
29
  try:
30
  hits = self.client.search(
@@ -49,6 +55,7 @@ class QdrantSearcher:
49
  "id": hit.id,
50
  "score": hit.score,
51
  "file_id": hit.payload.get('file_id'),
 
52
  "organization_id": hit.payload.get('organization_id'),
53
  "chunk_index": hit.payload.get('chunk_index'),
54
  "chunk_text": hit.payload.get('chunk_text'),
@@ -59,3 +66,59 @@ class QdrantSearcher:
59
  logging.info(f"Document search completed with {len(hits_list)} hits")
60
  logging.info(f"Hits: {hits_list}")
61
  return hits_list, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  from qdrant_client import QdrantClient
5
  from qdrant_client.http.models import Filter, FieldCondition
6
+ from collections import defaultdict
7
 
8
  class QdrantSearcher:
9
  def __init__(self, qdrant_url, access_token):
10
  self.client = QdrantClient(url=qdrant_url, api_key=access_token)
11
 
12
+ def search_documents(self, collection_name, query_embedding, user_id, limit=3,similarity_threshold=0.6, file_id=None):
13
  logging.info("Starting document search")
14
 
15
  # Ensure the query_embedding is in the correct format (flat list of floats)
 
24
  if not all(isinstance(x, float) for x in query_embedding):
25
  raise ValueError("All elements in query_embedding must be of type float")
26
 
27
+ filter_conditions = [FieldCondition(key="user_id", match={"value": user_id})]
28
+
29
+ if file_id:
30
+ filter_conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
31
+
32
  # Filter by user_id
33
+ query_filter = Filter(must=filter_conditions)
34
  logging.info(f"Performing search using the precomputed embeddings for user_id: {user_id}")
35
  try:
36
  hits = self.client.search(
 
55
  "id": hit.id,
56
  "score": hit.score,
57
  "file_id": hit.payload.get('file_id'),
58
+ "file_name": hit.payload.get('file_name'),
59
  "organization_id": hit.payload.get('organization_id'),
60
  "chunk_index": hit.payload.get('chunk_index'),
61
  "chunk_text": hit.payload.get('chunk_text'),
 
66
  logging.info(f"Document search completed with {len(hits_list)} hits")
67
  logging.info(f"Hits: {hits_list}")
68
  return hits_list, None
69
+
70
+ def search_documents_grouped(self, collection_name, query_embedding, user_id, limit=60, similarity_threshold=0.6, file_id=None):
71
+ logging.info("Starting grouped document search")
72
+
73
+ if isinstance(query_embedding, torch.Tensor):
74
+ query_embedding = query_embedding.detach().numpy().flatten().tolist()
75
+ elif isinstance(query_embedding, np.ndarray):
76
+ query_embedding = query_embedding.flatten().tolist()
77
+ else:
78
+ raise ValueError("query_embedding must be a torch.Tensor or numpy.ndarray")
79
+
80
+ if not all(isinstance(x, float) for x in query_embedding):
81
+ raise ValueError("All elements in query_embedding must be of type float")
82
+ #query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
83
+ filter_conditions = [FieldCondition(key="user_id", match={"value": user_id})]
84
+
85
+ if file_id:
86
+ filter_conditions.append(FieldCondition(key="file_id", match={"value": file_id}))
87
+
88
+ # Filter by user_id
89
+ query_filter = Filter(must=filter_conditions)
90
+ logging.info(f"Performing grouped search using the precomputed embeddings for user_id: {user_id}")
91
+ try:
92
+ hits = self.client.search(
93
+ collection_name=collection_name,
94
+ query_vector=query_embedding,
95
+ limit=limit,
96
+ query_filter=query_filter
97
+ )
98
+ except Exception as e:
99
+ logging.error(f"Error during Qdrant search: {e}")
100
+ return None, str(e)
101
+
102
+ #filtered_hits = [hit for hit in hits if hit.score >= similarity_threshold]
103
+
104
+ if not hits:
105
+ logging.info("No documents found for the given query")
106
+ return None, "No documents found for the given query."
107
+
108
+ # Group hits by filename and calculate average score
109
+ grouped_hits = defaultdict(list)
110
+ for hit in hits:
111
+ grouped_hits[hit.payload.get('file_name')].append(hit.score)
112
+
113
+ grouped_results = []
114
+ for file_name, scores in grouped_hits.items():
115
+ average_score = sum(scores) / len(scores)
116
+ grouped_results.append({
117
+ "file_name": file_name,
118
+ "average_score": average_score
119
+ })
120
+
121
+ logging.info(f"Grouped search completed with {len(grouped_results)} results")
122
+ logging.info(f"Grouped Hits: {grouped_results}")
123
+ return grouped_results, None
124
+