import os
import tempfile
from typing import List

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from aimakerspace.openai_utils.chatmodel import ChatOpenAI
from aimakerspace.vectordatabase import VectorDatabase
from app import RetrievalAugmentedQAPipeline, process_file

app = FastAPI()

# In-memory store for per-upload pipelines. Note: this is not production-ready;
# state is lost on restart and is not shared across worker processes.
app.state.pipelines = {}

# CORS middleware configuration. The "*" entry makes the explicit origins
# redundant; remove it (and allow_credentials) outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3001",  # Development React server
        "http://localhost:7860",  # Production nginx server
        "http://localhost",       # Just in case
        "*",                      # Allow all origins in development
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)


class ChatResponse(BaseModel):
    response: str
    context: List[tuple]


class ChatRequest(BaseModel):
    query: str


@app.post("/api/upload", response_model=dict)
async def upload_file(file: UploadFile = File(...)):
    temp_path = None
    try:
        # Persist the upload to a temporary file so process_file can read it
        # from disk; keep the original extension so the file type is detected.
        suffix = f".{file.filename.split('.')[-1]}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(await file.read())
            temp_path = temp_file.name

        # Split the document into text chunks using the existing helper
        texts = process_file(temp_path, file.filename)

        # Embed the chunks into a vector database
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)

        # Wire the retriever and chat model into a RAG pipeline
        pipeline = RetrievalAugmentedQAPipeline(
            vector_db_retriever=vector_db,
            llm=ChatOpenAI(),
        )

        pipeline_id = str(len(app.state.pipelines))
        app.state.pipelines[pipeline_id] = pipeline

        return {"pipeline_id": pipeline_id, "message": "File processed successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up the temporary file even if processing failed
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)


@app.post("/api/chat/{pipeline_id}", response_model=ChatResponse)
async def chat(pipeline_id: str, request: ChatRequest):
    # Raise the 404 outside the try block so the broad exception handler
    # below does not convert it into a 500
    pipeline = app.state.pipelines.get(pipeline_id)
    if pipeline is None:
        raise HTTPException(status_code=404, detail="Pipeline not found. Please upload a file first.")

    try:
        result = await pipeline.arun_pipeline(request.query)

        # Drain the streaming response into a single string
        response_text = ""
        async for chunk in result["response"]:
            response_text += chunk

        return ChatResponse(response=response_text, context=result["context"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
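# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the app above): with the
# server running locally on port 8000 and a readable "sample.pdf" on disk that
# process_file can handle, the two endpoints can be exercised with the
# third-party `requests` library. Kept as a comment so this module stays
# importable.
#
#   import requests
#
#   with open("sample.pdf", "rb") as f:
#       upload = requests.post(
#           "http://localhost:8000/api/upload",
#           files={"file": ("sample.pdf", f)},
#       ).json()
#
#   answer = requests.post(
#       f"http://localhost:8000/api/chat/{upload['pipeline_id']}",
#       json={"query": "What is this document about?"},
#   ).json()
#   print(answer["response"])
# ---------------------------------------------------------------------------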