prasunsrivastava
Added the app files with the first version.
bd9a582
raw
history blame
3.32 kB
import os
from typing import List
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import asyncio
import tempfile
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
from app import (
RetrievalAugmentedQAPipeline,
process_file,
system_role_prompt,
user_role_prompt,
)
# FastAPI application instance; all endpoints below are registered on it.
app = FastAPI()
# Update CORS middleware configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3001",  # Development React server
        "http://localhost:7860",  # Production nginx server
        "http://localhost",  # Just in case
        "*",  # Allow all origins in development
        # NOTE(review): per the CORS spec, browsers reject credentialed
        # responses carrying Access-Control-Allow-Origin: "*", so the
        # wildcard combined with allow_credentials=True below will not work
        # for credentialed requests — confirm whether credentials are needed,
        # or remove the wildcard before production.
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
class ChatResponse(BaseModel):
    """Response payload returned by the /api/chat/{pipeline_id} endpoint."""

    # Full answer text, accumulated from the pipeline's streaming response.
    response: str
    # Retrieved context chunks passed to the LLM; tuple shape is defined by
    # the RetrievalAugmentedQAPipeline (not visible here) — verify at source.
    context: List[tuple]
class ChatRequest(BaseModel):
    """Request payload for the /api/chat/{pipeline_id} endpoint."""

    # The user's question to run through the retrieval-augmented pipeline.
    query: str
@app.post("/api/upload", response_model=dict)
async def upload_file(file: UploadFile = File(...)):
try:
# Create a temporary file to store the upload
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file.filename.split('.')[-1]}") as temp_file:
content = await file.read()
temp_file.write(content)
temp_file.flush()
# Process the file using existing function
texts = process_file(temp_file.name, file.filename)
# Create vector database
vector_db = VectorDatabase()
vector_db = await vector_db.abuild_from_list(texts)
# Create chat model
chat_openai = ChatOpenAI()
# Create pipeline
pipeline = RetrievalAugmentedQAPipeline(
vector_db_retriever=vector_db,
llm=chat_openai
)
# Store the pipeline in memory (Note: this is not production-ready)
if not hasattr(app, 'pipelines'):
app.pipelines = {}
pipeline_id = str(len(app.pipelines))
app.pipelines[pipeline_id] = pipeline
# Clean up temporary file
os.unlink(temp_file.name)
return {"pipeline_id": pipeline_id, "message": "File processed successfully"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/chat/{pipeline_id}", response_model=ChatResponse)
async def chat(pipeline_id: str, request: ChatRequest):
try:
if not hasattr(app, 'pipelines') or pipeline_id not in app.pipelines:
raise HTTPException(status_code=404, detail="Pipeline not found. Please upload a file first.")
pipeline = app.pipelines[pipeline_id]
result = await pipeline.arun_pipeline(request.query)
# Collect the streaming response
response_text = ""
async for chunk in result["response"]:
response_text += chunk
return ChatResponse(
response=response_text,
context=result["context"]
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)