Spaces:
Sleeping
Sleeping
File size: 9,484 Bytes
500c1ba ce94de4 500c1ba 567e7ba bcd2179 500c1ba 21c27da 500c1ba 21c27da 500c1ba b619001 500c1ba e014b5f 500c1ba b619001 500c1ba 8411b7d 500c1ba 21c27da 500c1ba 21c27da b619001 ce94de4 b619001 21c27da 500c1ba 21c27da 500c1ba 21c27da 500c1ba 7d3c394 21c27da 3408e43 ce94de4 7d3c394 5897f5d b8ef5f6 567e7ba 5897f5d 21c27da 5897f5d 500c1ba 7d3c394 5897f5d 7d3c394 567e7ba 500c1ba 21c27da 500c1ba d343a87 bcd2179 500c1ba 9db95db 21c27da 500c1ba 5897f5d a6cce41 36331f1 567e7ba a6cce41 567e7ba 500c1ba 21c27da 500c1ba 5897f5d d343a87 36331f1 567e7ba 5213518 500c1ba 5897f5d 7d3c394 567e7ba d343a87 500c1ba bcd2179 500c1ba 21c27da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
from huggingface_hub import login
from fastapi import FastAPI, Depends, HTTPException
import logging
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
from services.qdrant_searcher import QdrantSearcher
from services.openai_service import generate_rag_response
from utils.auth import token_required
from dotenv import load_dotenv
import os
import torch
from utils.auth_x import x_api_key_auth
# Load environment variables from .env file
load_dotenv()

# Initialize FastAPI application
app = FastAPI()

# Set the cache directory for Hugging Face (must happen before any model download)
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
hf_home_dir = os.environ["HF_HOME"]
# exist_ok=True avoids the check-then-create race of the exists()/makedirs() pair
os.makedirs(hf_home_dir, exist_ok=True)

# Setup logging using Python's standard logging library
logging.basicConfig(level=logging.INFO)

# Authenticate with the Hugging Face Hub so the model can be downloaded.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if huggingface_token:
    try:
        login(token=huggingface_token, add_to_git_credential=True)
        logging.info("Successfully logged into Hugging Face Hub.")
    except Exception as e:
        logging.error(f"Failed to log into Hugging Face Hub: {e}")
        # RuntimeError, not HTTPException: this code runs at import/startup time,
        # outside any HTTP request, so FastAPI's exception handling never sees it.
        raise RuntimeError("Failed to log into Hugging Face Hub.") from e
else:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")

# Qdrant connection settings — validated up front so startup fails fast.
qdrant_url = os.getenv('QDRANT_URL')
access_token = os.getenv('QDRANT_ACCESS_TOKEN')
if not qdrant_url or not access_token:
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")

# Load the embedding model and tokenizer. trust_remote_code=True is required
# because nomic-embed-text ships custom modeling code on the Hub.
try:
    tokenizer = AutoTokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
    model = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
    logging.info("Successfully loaded the model and tokenizer with transformers.")
    # Initialize the Qdrant searcher only after the model loaded successfully.
    # (No `global` statement needed — this is already module scope.)
    searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
except Exception as e:
    logging.error(f"Failed to load the model or initialize searcher: {e}")
    raise RuntimeError("Failed to load the custom model or initialize searcher.") from e
# Function to embed text using the model
def embed_text(text):
    """Embed text with the nomic model using mean pooling.

    Args:
        text: A string (or list of strings) to embed; tokenized with
            padding and truncation.

    Returns:
        A numpy array of pooled embeddings, one row per input text.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: disable autograd so no computation graph is built
    # (saves memory; the numeric output is unchanged).
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean pooling over the token dimension (dim=1) of last_hidden_state.
    # NOTE(review): this averages padding positions too; attention-mask-weighted
    # pooling would be more faithful, but behavior is kept as-is.
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.detach().numpy()
# Define the request body models
class SearchDocumentsRequest(BaseModel):
    """Request body for POST /api/search-documents."""
    query: str  # free-text query; embedded before the vector search
    limit: int = 3  # maximum number of hits to return
class GenerateRAGRequest(BaseModel):
    """Request body for POST /api/generate-rag-response."""
    search_query: str  # query used both for retrieval and RAG generation
class XApiKeyRequest(BaseModel):
    """Request body for the x-api-key authenticated /v1 endpoints.

    Identity is carried in the body here (unlike the JWT endpoints,
    which take it from the token).
    """
    organization_id: str  # calling organization; currently logged only
    user_id: str  # used to scope the Qdrant search
    search_query: str  # query used for retrieval (and RAG generation)
@app.get("/")
async def root():
    """Landing endpoint: greets callers and points them at the API routes."""
    welcome = {"message": "Welcome to the Search and RAG API!, go to relevant address for API request"}
    return welcome
# Define the search documents endpoint
# Define the search documents endpoint
@app.post("/api/search-documents")
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Embed the query and run a vector search scoped to the authenticated user.

    Returns the raw list of hits from the Qdrant searcher.

    Raises:
        HTTPException: 401 when the JWT lacks customer_id/user_id,
            500 on search backend errors.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to search documents")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.query)
        logging.info("Search query: %s", body.query)  # was a bare print(); route it through logging
        collection_name = "embed"  # collection where the embeddings are stored
        logging.info("Performing search using the precomputed embeddings")
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return hits
    except HTTPException:
        # Bug fix: without this clause the HTTPException raised above was caught
        # by the generic handler below and re-wrapped with a mangled detail.
        raise
    except Exception as e:
        logging.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Define the generate RAG response endpoint
# Define the generate RAG response endpoint
@app.post("/api/generate-rag-response")
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Retrieve documents for the query, then generate a RAG answer from them.

    Returns:
        {"response": <generated text>} on success.

    Raises:
        HTTPException: 401 when the JWT lacks customer_id/user_id,
            500 on search or generation errors.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")

    logging.info("Received request to generate RAG response")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        logging.info("Search query: %s", body.search_query)  # was a bare print(); route it through logging
        collection_name = "embed"  # collection where the embeddings are stored
        # Perform search using the precomputed embeddings (default limit)
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info("Generating RAG response")
        # Generate the RAG response using the retrieved documents
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Bug fix: without this clause the HTTPExceptions raised above were
        # caught by the generic handler below and re-wrapped with a mangled detail.
        raise
    except Exception as e:
        logging.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/search-documents/v1")
async def search_documents_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    """x-api-key variant of document search; identity comes from the body.

    Raises:
        HTTPException: 401 when the API key is not authorized,
            500 on search backend errors.
    """
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")

    organization_id = body.organization_id
    user_id = body.user_id
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to search documents with x-api-key auth")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        collection_name = "embed"  # collection where the embeddings are stored
        # Perform search using the precomputed embeddings (fixed limit of 3)
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, limit=3)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info(f"Document search completed with {len(hits)} hits")
        return hits
    except HTTPException:
        # Bug fix: without this clause the HTTPException raised above was caught
        # by the generic handler below and re-wrapped with a mangled detail.
        raise
    except Exception as e:
        logging.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/generate-rag-response/v1")
async def generate_rag_response_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    """x-api-key variant of RAG generation; identity comes from the body.

    Returns:
        {"response": <generated text>} on success.

    Raises:
        HTTPException: 401 when the API key is not authorized,
            500 on search or generation errors.
    """
    # Assuming x_api_key_auth validates the key
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")

    organization_id = body.organization_id
    user_id = body.user_id
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to generate RAG response with x-api-key auth")
    try:
        logging.info("Starting document search")
        # Encode the query using the custom embedding function
        query_embedding = embed_text(body.search_query)
        collection_name = "embed"  # collection where the embeddings are stored
        # Perform search using the precomputed embeddings (default limit)
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)

        logging.info("Generating RAG response")
        # Generate the RAG response using the retrieved documents
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Bug fix: without this clause the HTTPExceptions raised above were
        # caught by the generic handler below and re-wrapped with a mangled detail.
        raise
    except Exception as e:
        logging.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Run a local development server when executed directly
# (in production the app is served by an external ASGI server instead).
if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)
|