|
import os |
|
from typing import List |
|
from fastapi import FastAPI, HTTPException, UploadFile, File |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel |
|
import asyncio |
|
import tempfile |
|
from aimakerspace.vectordatabase import VectorDatabase |
|
from aimakerspace.openai_utils.chatmodel import ChatOpenAI |
|
|
|
from app import ( |
|
RetrievalAugmentedQAPipeline, |
|
process_file, |
|
system_role_prompt, |
|
user_role_prompt, |
|
) |
|
|
|
# ASGI application object; the route decorators below register against it.
app = FastAPI()

# CORS settings, kept in one mapping so the whole policy can be read at a glance.
# NOTE(review): the origin list contains "*" next to explicit hosts while
# allow_credentials is True. The FastAPI CORS documentation says wildcards
# should not be combined with credentials -- confirm which policy is intended
# before tightening either setting.
_cors_options = {
    "allow_origins": [
        "http://localhost:3001",
        "http://localhost:7860",
        "http://localhost",
        "*",
    ],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "expose_headers": ["*"],
}

app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
# Response body returned by the chat endpoint.
class ChatResponse(BaseModel):
    # Complete answer text (the chat handler joins the streamed chunks).
    response: str
    # Context entries taken from the pipeline result; each entry is a tuple
    # whose element types are not constrained by this model.
    context: List[tuple]
|
|
|
# Request body accepted by the chat endpoint.
class ChatRequest(BaseModel):
    # The user's question; the chat handler forwards it to the pipeline.
    query: str
|
|
|
@app.post("/api/upload", response_model=dict)
async def upload_file(file: UploadFile = File(...)):
    """Index an uploaded document and register a QA pipeline for it.

    The upload is written to a temporary file, handed to ``process_file``,
    embedded into a new ``VectorDatabase`` and wrapped in a
    ``RetrievalAugmentedQAPipeline`` that is stored in ``app.pipelines``.

    Args:
        file: The uploaded document.

    Returns:
        A dict with the new ``pipeline_id`` and a success ``message``.

    Raises:
        HTTPException: 500 carrying the error text if any step fails.
    """
    try:
        # Derive the temp-file suffix from the last extension of the
        # client-supplied name. Tolerate a missing name or a name without a
        # dot, and drop any directory part so the suffix can never contain a
        # path separator.
        base_name = os.path.basename(file.filename or "")
        _, dot, extension = base_name.rpartition(".")
        suffix = f".{extension}" if dot and extension else ""

        temp_path = None
        try:
            content = await file.read()

            # delete=False: the file must outlive the `with` so it can be
            # reopened by path; it is removed in the `finally` below.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_path = temp_file.name
                temp_file.write(content)

            texts = process_file(temp_path, file.filename)

            vector_db = await VectorDatabase().abuild_from_list(texts)

            pipeline = RetrievalAugmentedQAPipeline(
                vector_db_retriever=vector_db,
                llm=ChatOpenAI(),
            )

            # Pipelines live on the app object, keyed by a running counter.
            if not hasattr(app, "pipelines"):
                app.pipelines = {}
            pipeline_id = str(len(app.pipelines))
            app.pipelines[pipeline_id] = pipeline
        finally:
            # Remove the temp file on failure as well as on success; the
            # original only cleaned up on the success path.
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except FileNotFoundError:
                    pass

        return {"pipeline_id": pipeline_id, "message": "File processed successfully"}

    except HTTPException:
        # Preserve an explicit HTTP error instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
@app.post("/api/chat/{pipeline_id}", response_model=ChatResponse)
async def chat(pipeline_id: str, request: ChatRequest):
    """Answer a query with a previously registered pipeline.

    Args:
        pipeline_id: Key under which the pipeline is stored in ``app.pipelines``.
        request: Request body carrying the user's ``query``.

    Returns:
        A ``ChatResponse`` holding the concatenated response chunks and the
        ``"context"`` entry of the pipeline result.

    Raises:
        HTTPException: 404 if no pipeline is registered under ``pipeline_id``;
            500 carrying the error text if running the pipeline fails.
    """
    # Look the pipeline up outside the try block: HTTPException derives from
    # Exception, so a 404 raised inside it would be caught by the generic
    # handler below and reach the client as a 500.
    pipeline = getattr(app, "pipelines", {}).get(pipeline_id)
    if pipeline is None:
        raise HTTPException(status_code=404, detail="Pipeline not found. Please upload a file first.")

    try:
        result = await pipeline.arun_pipeline(request.query)

        # result["response"] is consumed as an async iterator of text chunks;
        # collect them into one string for the non-streaming response.
        response_text = "".join([chunk async for chunk in result["response"]])

        return ChatResponse(
            response=response_text,
            context=result["context"],
        )
    except HTTPException:
        # Preserve an explicit HTTP error instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
if __name__ == "__main__":
    # Run directly as a script: serve the app with uvicorn on all interfaces,
    # port 8000. The import is local so uvicorn is only needed in this mode.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)