# Captured from a Hugging Face Spaces page (Space status: Sleeping).
import logging
import os
import time
from typing import Optional

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, Depends, HTTPException
from huggingface_hub import login
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from services.openai_service import generate_rag_response
from services.qdrant_searcher import QdrantSearcher
from utils.auth import token_required
from utils.auth_x import x_api_key_auth
# Populate os.environ from a local .env file (when one is present) so that
# the os.getenv() lookups further down in this module can see those values.
load_dotenv()

# The ASGI application object; it is what the __main__ block hands to uvicorn.
app = FastAPI()
# Point the Hugging Face cache at /tmp/huggingface_cache and make sure the
# directory exists.
#
# NOTE: huggingface_hub and transformers read HF_HOME when they are first
# imported, and both are imported at the top of this file before this line
# runs. This assignment therefore does NOT move their default cache. Code
# that wants this location must pass it explicitly (e.g. ``cache_dir=``).
# To relocate the whole cache, including downloaded remote-code modules,
# set HF_HOME in the environment before the process starts.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
hf_home_dir = os.environ["HF_HOME"]
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(hf_home_dir, exist_ok=True)
# Configure root logging FIRST. When the root logger has no handlers, the
# module-level helpers (logging.info/error/...) call logging.basicConfig()
# implicitly. After that, an explicit basicConfig() is a silent no-op. Any
# logging call placed before this line would therefore leave the root level
# at WARNING and drop every logging.info() in this module.
logging.basicConfig(level=logging.INFO)

# Qdrant collection the search endpoints query for precomputed embeddings.
collection_name = os.getenv('QDRANT_COLLECTION_NAME')
logging.info(f"Collection name: {collection_name}")
if not collection_name:
    # Not fatal here (matches the original start-up behaviour), but searches
    # will not target a real collection until this is configured.
    logging.warning("QDRANT_COLLECTION_NAME is not set; document searches will fail.")
# Authenticate with the Hugging Face Hub using HUGGINGFACE_HUB_TOKEN.
# A missing token or a failed login aborts start-up.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
if not huggingface_token:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")
try:
    login(token=huggingface_token, add_to_git_credential=True)
except Exception as e:
    logging.error(f"Failed to log into Hugging Face Hub: {e}")
    # This runs at import time, outside any request, so an HTTPException has
    # no client to be returned to. Raise a plain start-up error instead and
    # keep the original cause attached.
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
logging.info("Successfully logged into Hugging Face Hub.")
# Qdrant connection settings. Both must be non-empty before the searcher can
# be constructed, otherwise start-up stops here.
qdrant_url, access_token = (os.getenv(var) for var in ('QDRANT_URL', 'QDRANT_ACCESS_TOKEN'))
if not (qdrant_url and access_token):
    raise ValueError("Qdrant URL or Access Token is not set. Please set the QDRANT_URL and QDRANT_ACCESS_TOKEN environment variables.")
# Embedding model for incoming queries. Presumably it must match the model
# that produced the vectors stored in Qdrant; confirm before changing it.
EMBEDDING_MODEL_NAME = 'nomic-ai/nomic-embed-text-v1.5'

try:
    # Pass cache_dir explicitly. HF_HOME is assigned in this module only
    # after huggingface_hub/transformers were imported, so their default
    # cache location does not honour it.
    cache_folder = os.path.join(hf_home_dir, "transformers_cache")
    # trust_remote_code=True is required because this model repo ships its
    # own modelling code.
    tokenizer = AutoTokenizer.from_pretrained(
        EMBEDDING_MODEL_NAME, trust_remote_code=True, cache_dir=cache_folder
    )
    model = AutoModel.from_pretrained(
        EMBEDDING_MODEL_NAME, trust_remote_code=True, cache_dir=cache_folder
    )
    logging.info("Successfully loaded the model and tokenizer with transformers.")
    # Build the searcher only once the model is available. A module-level
    # assignment already makes `searcher` global, so no `global` is needed.
    searcher = QdrantSearcher(qdrant_url=qdrant_url, access_token=access_token)
except Exception as e:
    logging.error(f"Failed to load the model or initialize searcher: {e}")
    # Import-time failure: HTTPException is meaningless outside a request.
    raise RuntimeError("Failed to load the custom model or initialize searcher.") from e
def embed_text(text):
    """Embed ``text`` with the module-level tokenizer/model via mean pooling.

    Args:
        text: a string, or a list of strings to embed as one padded batch.

    Returns:
        A numpy array with one pooled vector per input, i.e. shape
        ``(batch, hidden_size)``. A single string gives ``(1, hidden_size)``.

    NOTE(review): this runs synchronously. When awaited from an ``async def``
    handler it blocks the event loop for the whole forward pass.
    NOTE(review): the nomic-embed-text-v1.5 model card recommends a
    ``"search_query: "`` prefix and L2 normalisation, and neither is applied
    here. Changing that also requires the stored document vectors to be
    produced the same way, so confirm before altering it.
    """
    inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state
    # Average over real tokens only. For a single string the mask is all
    # ones, so this equals a plain mean over the sequence. In a padded batch
    # it keeps pad positions from diluting the shorter inputs.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    embeddings = summed / counts
    return embeddings.cpu().numpy()
# Request body model for the JWT-authenticated document search endpoint.
class SearchDocumentsRequest(BaseModel):
    """Body accepted by ``search_documents``."""

    query: str  # free-text query to embed and search for
    limit: int = 3  # forwarded to the searcher as the result limit
    # Optional[...] so an explicit JSON null is accepted. Pydantic v2 does not
    # treat a bare ``str = None`` annotation as optional.
    file_id: Optional[str] = None  # when set, restricts the search to one file
class GenerateRAGRequest(BaseModel):
    """Body accepted by ``generate_rag_response_api``."""

    search_query: str  # question used both for retrieval and generation
    # Optional[...] so an explicit JSON null is accepted. Pydantic v2 does not
    # treat a bare ``str = None`` annotation as optional.
    file_id: Optional[str] = None  # when set, restricts retrieval to one file
class XApiKeyRequest(BaseModel):
    """Body for the x-api-key endpoints, which take the caller identity in the body."""

    organization_id: str  # logged by the handlers; not used for filtering here
    user_id: str  # passed to the searcher to scope results
    search_query: str  # query text to embed (and answer, for RAG)
    # Optional[...] so an explicit JSON null is accepted. Pydantic v2 does not
    # treat a bare ``str = None`` annotation as optional.
    file_id: Optional[str] = None
async def root():
    """Return a static welcome payload; this handler takes no credentials.

    NOTE(review): no route decorator is visible on this handler. Presumably
    it was registered with ``@app.get("/")``; confirm.
    """
    greeting = "Welcome to the Search and RAG API!, go to relevant address for API request"
    payload = {"message": greeting}
    return payload
# Define the search documents endpoint
async def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required)
):
    """Embed ``body.query`` and return the closest documents for the caller.

    NOTE(review): no route decorator is visible on this handler, so FastAPI
    will not expose it unless it is registered elsewhere; confirm.

    Args:
        body: query text, result ``limit`` and optional ``file_id`` filter.
        credentials: ``(customer_id, user_id)`` tuple from ``token_required``.

    Returns:
        ``(hits, time_taken)``: the searcher's hits and the elapsed seconds.

    Raises:
        HTTPException: 401 when the token lacks ``customer_id`` or
            ``user_id``. 500 when the search fails, with the error as detail.
    """
    customer_id, user_id = credentials
    start_time = time.time()
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
    logging.info("Received request to search documents")
    try:
        logging.info("Starting document search")
        query_embedding = embed_text(body.query)
        logging.info(f'search query {body.query}')
        logging.info("Performing search using the precomputed embeddings")
        # Exactly one search. Previously the file-filtered result was always
        # overwritten by an unfiltered search.
        if body.file_id:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit, file_id=body.file_id)
        else:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id, body.limit)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        time_taken = time.time() - start_time
        return hits, time_taken
    except HTTPException:
        # Already carries the intended status/detail. The generic handler
        # below would replace detail with the stringified exception.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Define the generate RAG response endpoint
async def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required)
):
    """Retrieve documents for ``body.search_query`` and generate a RAG answer.

    NOTE(review): no route decorator is visible on this handler, so FastAPI
    will not expose it unless it is registered elsewhere; confirm.

    Args:
        body: the search query and optional ``file_id`` filter.
        credentials: ``(customer_id, user_id)`` tuple from ``token_required``.

    Returns:
        ``{"response": <generated answer>}``.

    Raises:
        HTTPException: 401 when the token lacks ``customer_id`` or
            ``user_id``. 500 when retrieval or generation fails, with the
            underlying error as detail.
    """
    customer_id, user_id = credentials
    start_time = time.time()
    if not customer_id or not user_id:
        logging.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
    logging.info("Received request to generate RAG response")
    try:
        search_start_time = time.time()
        logging.info("Starting document search")
        query_embedding = embed_text(body.search_query)
        logging.info(f'search query {body.search_query}')
        # Only filter by file when a (non-empty) file_id was supplied.
        if body.file_id:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id, file_id=body.file_id)
        else:
            hits, error = searcher.search_documents(collection_name, query_embedding, user_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        search_time_taken = time.time() - search_start_time

        logging.info("Generating RAG response")
        rag_start_time = time.time()
        response, error = generate_rag_response(hits, body.search_query)
        rag_time_taken = time.time() - rag_start_time
        total_time = time.time() - start_time
        logging.info(f"Search time: {search_time_taken}, RAG time: {rag_time_taken}, Total time: {total_time}")
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        return {"response": response}
    except HTTPException:
        # Already carries the intended status/detail. The generic handler
        # below would replace detail with the stringified exception.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def search_documents_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    """Search documents for ``body.user_id``, authenticated by x-api-key.

    NOTE(review): no route decorator is visible on this handler, so FastAPI
    will not expose it unless it is registered elsewhere; confirm.

    Args:
        body: caller identity, query text and optional ``file_id``.
        authorized: result of ``x_api_key_auth``.

    Returns:
        The searcher's hits (at most 3 are requested).

    Raises:
        HTTPException: 401 when not authorized. 500 when the search fails,
            with the error as detail.
    """
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    start_time = time.time()
    organization_id = body.organization_id
    user_id = body.user_id
    file_id = body.file_id
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to search documents with x-api-key auth")
    try:
        logging.info("Starting document search")
        query_embedding = embed_text(body.search_query)
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, limit=3, file_id=file_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info(f"Document search completed with {len(hits)} hits")
        logging.info(f"Time taken: {time.time() - start_time}")
        return hits
    except HTTPException:
        # Already carries the intended status/detail. The generic handler
        # below would replace detail with the stringified exception.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def generate_rag_response_x_api_key(
    body: XApiKeyRequest,
    authorized: bool = Depends(x_api_key_auth)
):
    """Retrieve documents and generate a RAG answer, authenticated by x-api-key.

    NOTE(review): no route decorator is visible on this handler, so FastAPI
    will not expose it unless it is registered elsewhere; confirm.

    Args:
        body: caller identity, query text and optional ``file_id``.
        authorized: result of ``x_api_key_auth``.

    Returns:
        ``{"response": <generated answer>}``.

    Raises:
        HTTPException: 401 when not authorized. 500 when retrieval or
            generation fails, with the underlying error as detail.
    """
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    start_time = time.time()
    organization_id = body.organization_id
    user_id = body.user_id
    file_id = body.file_id
    logging.info(f'search query {body.search_query}')
    logging.info(f"organization_id: {organization_id}, user_id: {user_id}")
    logging.info("Received request to generate RAG response with x-api-key auth")
    try:
        logging.info("Starting document search")
        query_embedding = embed_text(body.search_query)
        hits, error = searcher.search_documents(collection_name, query_embedding, user_id, file_id=file_id)
        if error:
            logging.error(f"Search documents error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info("Generating RAG response")
        response, error = generate_rag_response(hits, body.search_query)
        if error:
            logging.error(f"Generate RAG response error: {error}")
            raise HTTPException(status_code=500, detail=error)
        logging.info(f"Time taken: {time.time() - start_time}")
        return {"response": response}
    except HTTPException:
        # Already carries the intended status/detail. The generic handler
        # below would replace detail with the stringified exception.
        raise
    except Exception as e:
        logging.exception(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == '__main__':
    # Script entry point: serve `app` on every interface, port 8000. The
    # import is local so uvicorn is only needed when run as a script.
    import uvicorn as _asgi_server
    _asgi_server.run(app, port=8000, host='0.0.0.0')