vhr1007 committed on
Commit
567e7ba
·
1 Parent(s): 5897f5d

adding embed-query

Browse files
Files changed (3) hide show
  1. app.py +9 -6
  2. requirements.txt +1 -0
  3. services/qdrant_searcher.py +16 -5
app.py CHANGED
@@ -8,6 +8,7 @@ from services.openai_service import generate_rag_response
8
  from utils.auth import token_required
9
  from dotenv import load_dotenv
10
  import os
 
11
 
12
  # Load environment variables from .env file
13
  load_dotenv()
@@ -57,7 +58,7 @@ try:
57
 
58
  # Initialize the Qdrant searcher after the model is successfully loaded
59
  global searcher # Ensure searcher is accessible globally if needed
60
- searcher = QdrantSearcher(encoder=model, qdrant_url=qdrant_url, access_token=access_token)
61
 
62
  except Exception as e:
63
  logging.error(f"Failed to load the model or initialize searcher: {e}")
@@ -68,7 +69,7 @@ def embed_text(text):
68
  inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
69
  outputs = model(**inputs)
70
  embeddings = outputs.last_hidden_state.mean(dim=1) # Example: mean pooling
71
- return embeddings
72
 
73
  # Define the request body models
74
  class SearchDocumentsRequest(BaseModel):
@@ -97,8 +98,10 @@ async def search_documents(
97
  # Encode the query using the custom embedding function
98
  query_embedding = embed_text(body.query)
99
 
100
- # Assuming searcher.search_documents uses these embeddings for search
101
- hits, error = searcher.search_documents("documents", query_embedding, user_id, body.limit)
 
 
102
 
103
  if error:
104
  logging.error(f"Search documents error: {error}")
@@ -128,7 +131,7 @@ async def generate_rag_response_api(
128
  # Encode the query using the custom embedding function
129
  query_embedding = embed_text(body.search_query)
130
 
131
- # Perform search using the encoded query
132
  hits, error = searcher.search_documents("documents", query_embedding, user_id)
133
 
134
  if error:
@@ -137,7 +140,7 @@ async def generate_rag_response_api(
137
 
138
  logging.info("Generating RAG response")
139
 
140
- # Assuming generate_rag_response uses the retrieved documents to generate a response
141
  response, error = generate_rag_response(hits, body.search_query)
142
 
143
  if error:
 
8
  from utils.auth import token_required
9
  from dotenv import load_dotenv
10
  import os
11
+ import torch
12
 
13
  # Load environment variables from .env file
14
  load_dotenv()
 
58
 
59
  # Initialize the Qdrant searcher after the model is successfully loaded
60
  global searcher # Ensure searcher is accessible globally if needed
61
+ searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
62
 
63
  except Exception as e:
64
  logging.error(f"Failed to load the model or initialize searcher: {e}")
 
69
  inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
70
  outputs = model(**inputs)
71
  embeddings = outputs.last_hidden_state.mean(dim=1) # Example: mean pooling
72
+ return embeddings.detach().numpy()
73
 
74
  # Define the request body models
75
  class SearchDocumentsRequest(BaseModel):
 
98
  # Encode the query using the custom embedding function
99
  query_embedding = embed_text(body.query)
100
 
101
+ collection_name = "my_embeddings" # Use the collection name where the embeddings are stored
102
+
103
+ # Perform search using the precomputed embeddings
104
+ hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)
105
 
106
  if error:
107
  logging.error(f"Search documents error: {error}")
 
131
  # Encode the query using the custom embedding function
132
  query_embedding = embed_text(body.search_query)
133
 
134
+ # Perform search using the precomputed embeddings
135
  hits, error = searcher.search_documents("documents", query_embedding, user_id)
136
 
137
  if error:
 
140
 
141
  logging.info("Generating RAG response")
142
 
143
+ # Generate the RAG response using the retrieved documents
144
  response, error = generate_rag_response(hits, body.search_query)
145
 
146
  if error:
requirements.txt CHANGED
@@ -5,6 +5,7 @@ cryptography>=3.4.7
5
  openai==1.37.1
6
  PyJWT==2.6.0
7
  nltk==3.6.7
 
8
  pydantic==2.8.2
9
  pydantic_core==2.20.1
10
  Pygments==2.18.0
 
5
  openai==1.37.1
6
  PyJWT==2.6.0
7
  nltk==3.6.7
8
+ numpy==1.22.0
9
  pydantic==2.8.2
10
  pydantic_core==2.20.1
11
  Pygments==2.18.0
services/qdrant_searcher.py CHANGED
@@ -1,21 +1,32 @@
1
  import logging
 
 
2
  from qdrant_client import QdrantClient
3
  from qdrant_client.http.models import Filter, FieldCondition
4
 
5
  class QdrantSearcher:
6
- def __init__(self, encoder, qdrant_url, access_token):
7
- self.encoder = encoder
8
  self.client = QdrantClient(url=qdrant_url, api_key=access_token)
9
 
10
- def search_documents(self, collection_name, query, user_id, limit=3):
11
  logging.info("Starting document search")
12
- query_vector = self.encoder.encode(query).tolist()
 
 
 
 
 
 
 
 
 
13
  query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
14
 
15
  try:
16
  hits = self.client.search(
17
  collection_name=collection_name,
18
- query_vector=query_vector,
19
  limit=limit,
20
  query_filter=query_filter
21
  )
 
1
  import logging
2
+ import torch
3
+ import numpy as np
4
  from qdrant_client import QdrantClient
5
  from qdrant_client.http.models import Filter, FieldCondition
6
 
7
  class QdrantSearcher:
8
+ def __init__(self, qdrant_url, access_token):
9
+ # Removed the encoder since embeddings are precomputed externally
10
  self.client = QdrantClient(url=qdrant_url, api_key=access_token)
11
 
12
+ def search_documents(self, collection_name, query_embedding, user_id, limit=3):
13
  logging.info("Starting document search")
14
+
15
+ # Ensure the query_embedding is in the correct format (list)
16
+ if isinstance(query_embedding, torch.Tensor):
17
+ query_embedding = query_embedding.detach().numpy().tolist()
18
+ logging.info("Converted query embedding to list")
19
+ elif isinstance(query_embedding, np.ndarray):
20
+ query_embedding = query_embedding.tolist()
21
+ logging.info("Converted query embedding to list")
22
+
23
+ # Filter by user_id
24
  query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
25
 
26
  try:
27
  hits = self.client.search(
28
  collection_name=collection_name,
29
+ query_vector=query_embedding,
30
  limit=limit,
31
  query_filter=query_filter
32
  )