# Origin: iShare, commit 38465ad ("Update main.py")
from fastapi import FastAPI, File, UploadFile, Request
from pydantic import BaseModel
from pathlib import Path
from fastapi import Form
from fastapi.responses import JSONResponse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from fastapi import Depends
# In FastAPI, Depends() is used to declare dependencies.
from huggingface_hub import InferenceClient
import numpy as np
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate, LLMChain
from langchain import HuggingFaceHub
from langchain.document_loaders import TextLoader
import torch
from sentence_transformers.util import semantic_search
import requests
import random
import string
import sys
import timeit
import datetime
import io
import os
from dotenv import load_dotenv
# Populate os.environ from a local .env file (if any) before reading configuration.
load_dotenv()

# Hugging Face settings, read from the environment in this order; each is None when unset.
HUGGINGFACEHUB_API_TOKEN, model_id, hf_token, repo_id = (
    os.getenv(_env_name)
    for _env_name in ("HUGGINGFACEHUB_API_TOKEN", "model_id", "hf_token", "repo_id")
)
def get_embeddings(input_str_texts, timeout=None):
    """Embed text(s) with the Hugging Face feature-extraction endpoint at ``api_url``.

    Args:
        input_str_texts: the payload sent as ``"inputs"``. In this module it is either
            a single question string or a list of text chunks.
        timeout: optional ``requests`` timeout in seconds. The default ``None`` keeps
            the previous behavior of waiting indefinitely, which matters because
            ``wait_for_model`` may block while the model loads.

    Returns:
        The decoded JSON response body. Presumably one embedding for a string input
        and a list of embeddings for a list input; confirm against the model in use.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
    """
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": input_str_texts, "options": {"wait_for_model": True}},
        timeout=timeout,
    )
    # Fail here with the real HTTP error instead of handing an error payload back to
    # callers, where it would only blow up later inside torch.FloatTensor.
    response.raise_for_status()
    return response.json()
def generate_random_string(length):
    """Return a string of ``length`` random lowercase ASCII letters (via ``random.choice``)."""
    alphabet = string.ascii_lowercase
    picks = [random.choice(alphabet) for _ in range(length)]
    return "".join(picks)
def remove_context(text):
    """Strip an echoed ``Context:`` section from generated text.

    If ``text`` contains the marker ``Context:``, everything from the start of
    ``text`` through the first blank line (two consecutive newlines) that follows
    the marker is removed, and the remainder is returned. If the marker is absent,
    or no blank line follows it (so the end of the context cannot be located),
    ``text`` is returned unchanged.
    """
    marker_pos = text.find('Context:')
    if marker_pos == -1:
        return text
    # Search for the blank line *after* the marker; one earlier in the text does not
    # terminate the context section.
    end_of_context = text.find('\n\n', marker_pos)
    if end_of_context == -1:
        # Without this guard, text[-1 + 2:] would silently drop the first character.
        return text
    return text[end_of_context + 2:]
# Feature-extraction endpoint and auth header used by get_embeddings().
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
headers = dict(Authorization=f"Bearer {hf_token}")

# Text-generation model (repo_id comes from the environment) and its generation settings.
llm = HuggingFaceHub(
    repo_id=repo_id,
    model_kwargs=dict(
        min_length=512,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.01,
        top_k=50,
        top_p=0.95,
        # NOTE(review): hard-coded, model-specific end-of-sequence token id; confirm it
        # matches the tokenizer of the model at repo_id.
        eos_token_id=49155,
    ),
)

# Disabled alternative: a stricter custom prompt for the QA chain.
# prompt_template = """
# You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question {question}. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
# Your response should be full and easy to understand.
# """
# PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)

# QA chain using LangChain's "stuff" strategy with its default prompt.
chain = load_qa_chain(llm=llm, chain_type="stuff")

app = FastAPI()
class FileToProcess(BaseModel):
    """Request model carrying the uploaded PDF for the QA endpoint.

    The endpoint receives it through ``Depends()``; ``File(...)`` declares
    ``uploaded_file`` as a required uploaded file.
    """

    uploaded_file: UploadFile = File(...)
@app.get("/")
async def home():
    """Liveness check: return a fixed message confirming the API is up."""
    status_message = "API Working!"
    return status_message
def _extract_pdf_text(pdf_bytes):
    """Return the concatenated text of every page of the PDF held in ``pdf_bytes``.

    Pages for which PyPDF2 extracts no text are skipped.
    """
    reader = PdfReader(io.BytesIO(pdf_bytes))
    page_texts = (page.extract_text() for page in reader.pages)
    return ''.join(text for text in page_texts if text)


def _load_context_documents(page_contents):
    """Load the retrieved chunks as LangChain documents via a throw-away text file.

    The chunks are serialised as ``str(page_contents)``, written to a uniquely named
    ``.txt`` file in the current working directory and read back with ``TextLoader``.
    The file is deleted afterwards.
    """
    # str(list) renders real newlines as the two characters backslash + "n"; replace
    # them with a space (not '') so words on adjacent lines are not glued together.
    context_text = str(page_contents).replace('\\n', ' ')
    context_path = generate_random_string(20) + ".txt"
    try:
        with open(context_path, "w", encoding="utf-8") as context_file:
            context_file.write(context_text)
        return TextLoader(context_path, encoding="utf-8").load()
    finally:
        # Do not leave one stray file per request in the working directory.
        if os.path.exists(context_path):
            os.remove(context_path)


def _clean_ai_response(raw_response):
    """Trim the raw LLM output down to the answer text.

    Steps, in order:
    - remove an echoed ``Context:`` preamble (via ``remove_context``);
    - cut the text at the first occurrence of several hard-coded markers (Spanish
      follow-up questions and ``<|end|>``);
    - collapse blank lines and remove leftover chat-role tokens;
    - drop everything from "Unhelpful Answer:", "Note:" or the feedback request on.
    """
    answer = remove_context(raw_response)
    for marker in ('¿Cuál es', '¿Cuáles', '¿Qué es', '<|end|>'):
        answer = answer.partition(marker)[0].strip()
    answer = answer.replace('\n\n', '\n')
    # No '<|end|>' can survive the cut above, so only the other role tokens need removing.
    for token in ('<|user|>', '<|system|>', '<|assistant|>'):
        answer = answer.replace(token, '')
    for marker in ('Unhelpful Answer:', 'Note:', 'Please provide feedback on how to improve the chatbot.'):
        answer = answer.split(marker)[0].strip()
    return answer


@app.post("/fastapi_file_upload_process")
async def pdf_file_qa_process(user_question: str, request: Request, file_to_process: FileToProcess = Depends()):
    """Answer ``user_question`` from the content of an uploaded PDF.

    Steps:
    1. Extract the PDF text and split it into 500-character chunks with 100
       characters of overlap.
    2. Embed the chunks and the question with ``get_embeddings``.
    3. Take the top-5 most similar chunks via ``semantic_search``.
    4. Run the QA ``chain`` over those chunks and clean up the generated text.

    Args:
        user_question: the question, from the ``user_question`` query parameter.
        request: the incoming request; only used to read the optional, undeclared
            ``filename`` query parameter for logging.
        file_to_process: form dependency carrying the uploaded PDF.

    Returns:
        ``{"AIResponse": <answer>}``, or a 422 ``JSONResponse`` when no text could
        be extracted from the PDF.
    """
    # NOTE(review): get_embeddings() and chain() are blocking calls inside an async
    # endpoint, so they stall the event loop for the whole request; consider running
    # them in a threadpool (or declaring the endpoint with plain ``def``).
    print("API Call Triggered.")
    start_0 = timeit.default_timer()
    uploaded_file = file_to_process.uploaded_file
    print(f"File received:{uploaded_file.filename}")
    # FastAPI already binds user_question from the query string. "filename" is
    # optional, so fall back to the upload's own name instead of crashing on None.
    filename = request.query_params.get("filename") or uploaded_file.filename
    print(f"User entered question: {user_question}")
    print(f"User uploaded file: {filename}")

    # Parse the PDF from memory. The client-supplied file name never reaches the
    # filesystem, and no per-request directory is left behind.
    raw_text = _extract_pdf_text(await uploaded_file.read())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,  # consecutive chunks share up to 100 characters
        length_function=len,
    )
    texts = text_splitter.split_text(raw_text)
    if not texts:
        # Embedding and searching an empty corpus would fail later with an opaque error.
        return JSONResponse(
            status_code=422,
            content={"detail": "No extractable text found in the uploaded PDF."},
        )

    db_embeddings = torch.FloatTensor(get_embeddings(texts))
    print(db_embeddings)
    print("db_embeddings created...")
    print(f"API Call Query Received: {user_question}")
    final_q_embedding = torch.FloatTensor(get_embeddings(user_question))
    print(final_q_embedding)

    print("Semantic Similarity Search Starts...")
    start_1 = timeit.default_timer()
    hits = semantic_search(final_q_embedding, db_embeddings, top_k=5)
    end_1 = timeit.default_timer()
    print("Semantic Similarity Search Ends...")
    print(f'Semantic Similarity Search共耗时: @ {end_1 - start_1}')

    # One query was searched, so hits[0] holds its ranked matches.
    page_contents = [texts[hit['corpus_id']] for hit in hits[0]]
    print(page_contents)
    loaded_documents = _load_context_documents(page_contents)
    print("*****loaded_documents******")
    print(loaded_documents)
    print("***********")
    print(user_question)
    print("*****question******")

    print("LLM Chain Starts...")
    start_2 = timeit.default_timer()
    temp_ai_response = chain({"input_documents": loaded_documents, "question": user_question}, return_only_outputs=False)
    end_2 = timeit.default_timer()
    print("LLM Chain Ends...")
    print(f'LLM Chain共耗时: @ {end_2 - start_2}')
    print(temp_ai_response)
    initial_ai_response = temp_ai_response['output_text']
    print(initial_ai_response)

    new_final_ai_response = _clean_ai_response(initial_ai_response)
    print(new_final_ai_response)
    end_0 = timeit.default_timer()
    print("API Call Ended.")
    print(f'API Call共耗时: @ {end_0 - start_0}')
    return {"AIResponse": new_final_ai_response}