File size: 3,467 Bytes
6e06893 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
import shutil
import shutil
import numpy as np
from uuid import uuid4
from io import BytesIO
from pydantic import BaseModel
from fastapi.staticfiles import StaticFiles
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, UploadFile, File, Response, status
from llama_index.readers import StringIterableReader, PDFReader, SimpleDirectoryReader
from llama_index import (
VectorStoreIndex,
ServiceContext,
set_global_service_context,
)
# from pyngrok import ngrok
import inference
# The ASGI application object; every route below is registered on it.
app = FastAPI()

# CORS is configured wide open: any origin, method and header is accepted,
# and credentialed requests are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class Message(BaseModel):
    """Request body for the retrieval endpoints."""

    # Text that the retrieval handlers pass to the query engine as the query.
    content: str
# Root directory under which each session's uploaded files are stored.
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair so that two
# processes starting at the same time cannot race between check and create.
os.makedirs("tmp", exist_ok=True)

# In-memory registry mapping a session id to the index built from the files
# uploaded for that session. It lives only as long as this process does.
vector_stores = {}
@app.post("/retriveal/ingest")
async def store_file(
    file: UploadFile = File(...)
):
    """Store an uploaded file in a new session and index it.

    A fresh UUID is generated as the session id, the upload is written to
    ``tmp/<uuid>/``, every document found in that directory is loaded and a
    vector index is built and kept in ``vector_stores`` under the id.

    Args:
        file: The uploaded file (multipart form field ``file``).

    Returns:
        ``{"uuid": <session id>}`` on success, or ``{"error": <message>}``
        if anything fails. Errors are reported in the body only; the HTTP
        status code is not changed.
    """
    try:
        print(file.filename)
        # The filename is supplied by the client. Keep only its last path
        # component so values such as "../../x" cannot escape the session
        # directory (backslashes are normalised first for Windows clients).
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            raise Exception("Invalid filename")

        session_id = str(uuid4())
        file_location = f"tmp/{session_id}"
        os.makedirs(file_location, exist_ok=True)
        with open(f"{file_location}/{safe_name}", "wb+") as f:
            shutil.copyfileobj(file.file, f)

        documents = SimpleDirectoryReader(file_location).load_data()
        vector_stores[session_id] = VectorStoreIndex.from_documents(documents)
        return jsonable_encoder({"uuid": session_id})
    except Exception as e:
        # Endpoint boundary: report the failure to the client as JSON.
        return jsonable_encoder({"error": str(e)})
@app.post("/retriveal/ingest/{id}")
async def store_file_with_id(
    id,
    file: UploadFile = File(...)
):
    """Add an uploaded file to the session ``id`` and rebuild its index.

    The upload is written to ``tmp/<id>/`` (created if missing), then every
    document in that directory is reloaded and a new vector index replaces
    whatever was stored in ``vector_stores`` for this id.

    Args:
        id: Session id taken from the URL path.
        file: The uploaded file (multipart form field ``file``).

    Returns:
        ``{"uuid": id}`` on success, or ``{"error": <message>}`` if anything
        fails. Errors are reported in the body only; the HTTP status code is
        not changed.
    """
    try:
        print(file.filename)
        if not id:
            raise Exception("Id is required")
        # The id becomes a directory name under tmp/. Reject anything that
        # could point outside of it ("..", or embedded path separators).
        if id in (".", "..") or "/" in id or "\\" in id:
            raise Exception("Invalid id")
        # Same concern for the client-supplied filename: keep only its last
        # path component.
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            raise Exception("Invalid filename")

        file_location = f"tmp/{id}"
        os.makedirs(file_location, exist_ok=True)
        with open(f"{file_location}/{safe_name}", "wb+") as f:
            shutil.copyfileobj(file.file, f)

        documents = SimpleDirectoryReader(file_location).load_data()
        vector_stores[id] = VectorStoreIndex.from_documents(documents)
        return jsonable_encoder({"uuid": id})
    except Exception as e:
        # Endpoint boundary: report the failure to the client as JSON.
        return jsonable_encoder({"error": str(e)})
@app.delete("/session/{id}")
async def delete_session(id):
    """Delete the session ``id``: its in-memory index and its files on disk.

    Args:
        id: Session id taken from the URL path.

    Returns:
        ``{"message": "ok"}`` on success, or ``{"error": <message>}`` if the
        id is rejected or the directory cannot be removed.
    """
    try:
        # The id is interpolated into the path handed to rmtree. An id of
        # ".." would resolve to the parent of tmp/ and delete it, so refuse
        # anything that is not a plain directory name.
        if not id or id in (".", "..") or "/" in id or "\\" in id:
            raise Exception("Invalid id")
        # Drop the index as well, otherwise it stays in memory (and remains
        # queryable) after the session has been deleted.
        vector_stores.pop(id, None)
        shutil.rmtree(f"tmp/{id}")
        return jsonable_encoder({"message": "ok"})
    except Exception as e:
        return jsonable_encoder({"error": str(e)})
@app.post("/retriveal/{id}")
async def inference(
    id,
    message: Message
):
    """Run a query against the index of session ``id`` and return the result.

    NOTE(review): the name ``inference`` is also bound by ``import inference``
    at the top of this file and by the streaming handler defined further
    down; the name is kept here so the endpoint's interface is unchanged.

    Args:
        id: Session id taken from the URL path.
        message: Request body; ``message.content`` is the query text.

    Returns:
        The object returned by the query engine's ``query`` call.

    Raises:
        Exception: If ``id`` is empty.
        HTTPException: 404 if no index is stored for ``id``.
    """
    from fastapi import HTTPException

    if not id:
        raise Exception("Id is required")
    index = vector_stores.get(id)
    if index is None:
        # Previously an unknown id escaped as a bare KeyError (HTTP 500).
        raise HTTPException(status_code=404, detail=f"Unknown session id: {id}")
    query_engine = index.as_query_engine()
    result = query_engine.query(message.content)
    return result
def stream_inference(gen):
    """Yield every item produced by ``gen``, in order.

    Args:
        gen: Any iterable of items to forward.

    Yields:
        Each item of ``gen`` unchanged.
    """
    yield from gen
@app.post("/retriveal/stream/{id}")
async def inference(
    id,
    message: Message
):
    """Run a query against the index of session ``id`` and stream the answer.

    NOTE(review): this definition reuses the name ``inference``, which is
    already bound earlier in this file (the ``import inference`` statement
    and the non-streaming handler); the name is kept here so the endpoint's
    interface is unchanged.

    Args:
        id: Session id taken from the URL path.
        message: Request body; ``message.content`` is the query text.

    Returns:
        A ``StreamingResponse`` that emits the items of the query result's
        ``response_gen`` as they are produced.

    Raises:
        Exception: If ``id`` is empty.
        HTTPException: 404 if no index is stored for ``id``.
    """
    from fastapi import HTTPException

    if not id:
        raise Exception("Id is required")
    index = vector_stores.get(id)
    if index is None:
        # Previously an unknown id escaped as a bare KeyError (HTTP 500).
        raise HTTPException(status_code=404, detail=f"Unknown session id: {id}")
    query_engine = index.as_query_engine(streaming=True)
    token_gen = query_engine.query(message.content).response_gen
    return StreamingResponse(stream_inference(token_gen))
# Serve the "static" directory at the site root. This mount is registered
# after all API routes above; presumably so they are matched before this
# catch-all path -- confirm against the framework's routing order.
app.mount(
    "/",
    StaticFiles(directory="static", html=True),
    name="static",
)
|