Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, File, UploadFile, Request | |
from pydantic import BaseModel | |
from pathlib import Path | |
from fastapi import Form | |
from fastapi.responses import JSONResponse | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from PyPDF2 import PdfReader | |
from fastapi import Depends | |
#在FastAPI中,Depends()函数用于声明依赖项 | |
from huggingface_hub import InferenceClient | |
import numpy as np | |
from langchain.chains.question_answering import load_qa_chain | |
from langchain import PromptTemplate, LLMChain | |
from langchain import HuggingFaceHub | |
from langchain.document_loaders import TextLoader | |
import torch | |
from sentence_transformers.util import semantic_search | |
import requests | |
import random | |
import string | |
import sys | |
import timeit | |
import datetime | |
import io | |
import os | |
from dotenv import load_dotenv | |
load_dotenv() | |
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN') | |
model_id = os.getenv('model_id') | |
hf_token = os.getenv('hf_token') | |
repo_id = os.getenv('repo_id') | |
def get_embeddings(input_str_texts): | |
response = requests.post(api_url, headers=headers, json={"inputs": input_str_texts, "options":{"wait_for_model":True}}) | |
return response.json() | |
def generate_random_string(length): | |
letters = string.ascii_lowercase | |
return ''.join(random.choice(letters) for i in range(length)) | |
def remove_context(text): | |
if 'Context:' in text: | |
end_of_context = text.find('\n\n') | |
return text[end_of_context + 2:] | |
else: | |
return text | |
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}" | |
headers = {"Authorization": f"Bearer {hf_token}"} | |
llm = HuggingFaceHub(repo_id=repo_id, | |
model_kwargs={"min_length":512, | |
"max_new_tokens":1024, "do_sample":True, | |
"temperature":0.01, | |
"top_k":50, | |
"top_p":0.95, "eos_token_id":49155}) | |
#prompt_template = """ | |
#You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question {question}. If you don't know the answer, just say that you don't know. DON'T try to make up an answer. | |
#Your response should be full and easy to understand. | |
#""" | |
#PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"]) | |
#chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT) | |
chain = load_qa_chain(llm=llm, chain_type="stuff") | |
app = FastAPI() | |
class FileToProcess(BaseModel): | |
uploaded_file: UploadFile = File(...) | |
async def home(): | |
return "API Working!" | |
#async def upload_file(user_question: str, file_to_process: FileToProcess = Depends()): | |
async def pdf_file_qa_process(user_question: str, request: Request, file_to_process: FileToProcess = Depends()): | |
print("API Call Triggered.") | |
start_0 = timeit.default_timer() | |
uploaded_file = file_to_process.uploaded_file | |
print("File received:"+uploaded_file.filename) | |
user_question = request.query_params.get("user_question") | |
filename = request.query_params.get("filename") | |
print("User entered question: "+user_question) | |
print("User uploaded file: "+filename) | |
random_string = generate_random_string(20) | |
file_path = Path.cwd() / random_string | |
file_path.mkdir(parents=True, exist_ok=True) | |
file_saved_in_api = file_path / uploaded_file.filename | |
print(file_saved_in_api) | |
with open(file_saved_in_api, "wb+") as file_object: | |
file_object.write(uploaded_file.file.read()) | |
text_splitter = RecursiveCharacterTextSplitter( | |
#separator = "\n", | |
chunk_size = 500, | |
chunk_overlap = 100, #striding over the text | |
length_function = len, | |
) | |
doc_reader = PdfReader(file_saved_in_api) | |
raw_text = '' | |
for i, page in enumerate(doc_reader.pages): | |
text = page.extract_text() | |
if text: | |
raw_text += text | |
temp_texts = text_splitter.split_text(raw_text) | |
texts=temp_texts | |
initial_embeddings=get_embeddings(temp_texts) | |
db_embeddings = torch.FloatTensor(initial_embeddings) | |
print(db_embeddings) | |
print("db_embeddings created...") | |
#question = var_query.query | |
question = user_question | |
print("API Call Query Received: "+question) | |
q_embedding=get_embeddings(question) | |
final_q_embedding = torch.FloatTensor(q_embedding) | |
print(final_q_embedding) | |
print("Semantic Similarity Search Starts...") | |
start_1 = timeit.default_timer() | |
hits = semantic_search(final_q_embedding, torch.FloatTensor(db_embeddings), top_k=5) | |
end_1 = timeit.default_timer() | |
print("Semantic Similarity Search Ends...") | |
print(f'Semantic Similarity Search共耗时: @ {end_1 - start_1}') | |
page_contents = [] | |
for i in range(len(hits[0])): | |
page_content = texts[hits[0][i]['corpus_id']] | |
page_contents.append(page_content) | |
print(page_contents) | |
temp_page_contents=str(page_contents) | |
final_page_contents = temp_page_contents.replace('\\n', '') | |
random_string_2=generate_random_string(20) | |
file_path = random_string_2 + ".txt" | |
with open(file_path, "w", encoding="utf-8") as file: | |
file.write(final_page_contents) | |
loader = TextLoader(file_path, encoding="utf-8") | |
loaded_documents = loader.load() | |
print("*****loaded_documents******") | |
print(loaded_documents) | |
print("***********") | |
print(question) | |
print("*****question******") | |
print("LLM Chain Starts...") | |
start_2 = timeit.default_timer() | |
temp_ai_response = chain({"input_documents": loaded_documents, "question": question}, return_only_outputs=False) | |
end_2 = timeit.default_timer() | |
print("LLM Chain Ends...") | |
print(f'LLM Chain共耗时: @ {end_2 - start_2}') | |
print(temp_ai_response) | |
initial_ai_response=temp_ai_response['output_text'] | |
print(initial_ai_response) | |
cleaned_initial_ai_response = remove_context(initial_ai_response) | |
#final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '') | |
final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip() | |
final_ai_response = final_ai_response.partition('¿Cuáles')[0].strip() | |
final_ai_response = final_ai_response.partition('¿Qué es')[0].strip() | |
final_ai_response = final_ai_response.partition('<|end|>')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '') | |
new_final_ai_response = final_ai_response.split('Unhelpful Answer:')[0].strip() | |
new_final_ai_response = new_final_ai_response.split('Note:')[0].strip() | |
new_final_ai_response = new_final_ai_response.split('Please provide feedback on how to improve the chatbot.')[0].strip() | |
print(new_final_ai_response) | |
end_0 = timeit.default_timer() | |
print("API Call Ended.") | |
print(f'API Call共耗时: @ {end_0 - start_0}') | |
return {"AIResponse": new_final_ai_response} | |
#return JSONResponse({"AIResponse": new_final_ai_response}) |