"""FastAPI service exposing document search and RAG answer generation.

At import time this module configures the Hugging Face cache directory, logs
into the Hugging Face Hub, loads the embedding model and builds a Qdrant
searcher. Missing configuration or a failed start-up step aborts the import.
"""

import logging
import os

# HF_HOME must be in the environment *before* huggingface_hub (also pulled in
# by sentence_transformers) is imported: huggingface_hub derives its cache and
# token paths from HF_HOME when its constants module loads, so assigning it
# after those imports has no effect on this process.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Ensure the cache directory exists (exist_ok avoids a check-then-create race).
cache_dir = os.environ["HF_HOME"]
os.makedirs(cache_dir, exist_ok=True)

from dotenv import load_dotenv  # noqa: E402  (must follow HF_HOME setup)
from fastapi import Depends, FastAPI, HTTPException  # noqa: E402
from huggingface_hub import login  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402

from services.openai_service import generate_rag_response  # noqa: E402
from services.qdrant_searcher import QdrantSearcher  # noqa: E402
from utils.auth import token_required  # noqa: E402

# Load environment variables from .env file
load_dotenv()

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI application
app = FastAPI()

# Qdrant collection queried by both endpoints.
_COLLECTION = "documents"

# --- Hugging Face Hub login -------------------------------------------------
# This runs at import time, outside any request, so failures are raised as
# RuntimeError/ValueError rather than HTTPException (which only has meaning
# inside a request handler).
huggingface_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not huggingface_token:
    raise ValueError(
        "Hugging Face token is not set. "
        "Please set the HUGGINGFACE_HUB_TOKEN environment variable."
    )
try:
    # Log in to Hugging Face without adding credentials to Git
    login(token=huggingface_token, add_to_git_credential=False)
except Exception as e:
    logger.exception("Failed to log into Hugging Face Hub: %s", e)
    raise RuntimeError("Failed to log into Hugging Face Hub.") from e
logger.info("Successfully logged into Hugging Face Hub.")

# --- Qdrant connection settings (both are required) -------------------------
qdrant_url = os.getenv("QDRANT_URL")
access_token = os.getenv("QDRANT_ACCESS_TOKEN")
if not qdrant_url or not access_token:
    raise ValueError(
        "Qdrant URL or Access Token is not set. Please set the QDRANT_URL and "
        "QDRANT_ACCESS_TOKEN environment variables."
    )

# --- Embedding model ---------------------------------------------------------
# NOTE(review): the nomic-embed-text-v1.5 model card loads this model with
# trust_remote_code=True. If loading fails here, that flag is the likely cause.
# It is deliberately not enabled because it executes code downloaded from the
# Hub -- confirm it is needed and pin a model revision before turning it on.
try:
    encoder = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5")
except Exception as e:
    logger.exception("Failed to load the SentenceTransformer model: %s", e)
    raise RuntimeError("Failed to load the SentenceTransformer model.") from e
logger.info("Successfully loaded the SentenceTransformer model.")

# Initialize the Qdrant searcher
searcher = QdrantSearcher(encoder, qdrant_url, access_token)


# --- Request body models -----------------------------------------------------
class SearchDocumentsRequest(BaseModel):
    # Free-text query to search for.
    query: str
    # Maximum number of hits to return.
    limit: int = 3


class GenerateRAGRequest(BaseModel):
    # Query used both for retrieval and as the question for generation.
    search_query: str


# --- Helpers -------------------------------------------------------------------
def _require_identity(credentials: tuple) -> tuple:
    """Unpack ``(customer_id, user_id)`` from the auth dependency.

    Raises:
        HTTPException: 401 if either value is missing/falsy.
    """
    customer_id, user_id = credentials
    if not customer_id or not user_id:
        logger.error("Failed to extract customer_id or user_id from the JWT token.")
        raise HTTPException(
            status_code=401,
            detail="Invalid token: missing customer_id or user_id",
        )
    return customer_id, user_id


def _internal_error(exc: Exception) -> HTTPException:
    """Log an unexpected exception and build the matching 500 response.

    Must be called from inside an ``except`` block so the traceback is logged.
    """
    logger.exception("Unexpected error: %s", exc)
    # NOTE(review): str(exc) exposes internal error text to API clients;
    # consider a generic message once callers no longer depend on it.
    return HTTPException(status_code=500, detail=str(exc))


# --- Endpoints -----------------------------------------------------------------
# The endpoints are plain ``def`` (not ``async def``) because they call the
# synchronous searcher/generator; FastAPI runs ``def`` endpoints in a worker
# thread pool, so these blocking calls no longer stall the event loop.
@app.post("/api/search-documents")
def search_documents(
    body: SearchDocumentsRequest,
    credentials: tuple = Depends(token_required),
):
    """Search the document collection for ``body.query``.

    Returns the hits produced by ``searcher.search_documents`` (up to
    ``body.limit``). The token's user_id is passed to the searcher
    (presumably to scope results to that user -- confirm in QdrantSearcher).

    Raises:
        HTTPException: 401 on missing identity, 500 on search failure.
    """
    _customer_id, user_id = _require_identity(credentials)
    logger.info("Received request to search documents")

    # Keep the try narrow so our own HTTPException below is not re-wrapped.
    try:
        hits, error = searcher.search_documents(_COLLECTION, body.query, user_id, body.limit)
    except Exception as e:
        raise _internal_error(e) from e
    if error:
        logger.error("Search documents error: %s", error)
        raise HTTPException(status_code=500, detail=error)
    return hits


@app.post("/api/generate-rag-response")
def generate_rag_response_api(
    body: GenerateRAGRequest,
    credentials: tuple = Depends(token_required),
):
    """Retrieve documents for ``body.search_query`` and generate an answer.

    Returns ``{"response": <generated text>}``.

    Raises:
        HTTPException: 401 on missing identity, 500 if retrieval or
            generation fails.
    """
    _customer_id, user_id = _require_identity(credentials)
    logger.info("Received request to generate RAG response")

    # Retrieval uses the searcher's default limit.
    try:
        hits, error = searcher.search_documents(_COLLECTION, body.search_query, user_id)
    except Exception as e:
        raise _internal_error(e) from e
    if error:
        logger.error("Search documents error: %s", error)
        raise HTTPException(status_code=500, detail=error)

    try:
        response, error = generate_rag_response(hits, body.search_query)
    except Exception as e:
        raise _internal_error(e) from e
    if error:
        logger.error("Generate RAG response error: %s", error)
        raise HTTPException(status_code=500, detail=error)

    return {"response": response}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)