import datetime
import io
import os
import random
import string
import sys
import timeit
from pathlib import Path

import numpy as np
import requests
import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, Form, Request, UploadFile  # In FastAPI, Depends() declares a dependency of an endpoint
from fastapi.responses import JSONResponse
from huggingface_hub import InferenceClient
from langchain import HuggingFaceHub, LLMChain, PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import BaseModel
from PyPDF2 import PdfReader
from sentence_transformers.util import semantic_search

# Populate os.environ from a local .env file before reading configuration.
load_dotenv()

# NOTE(review): not referenced anywhere else in the visible code; presumably
# HuggingFaceHub picks it up from the environment itself — confirm.
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
model_id = os.getenv('model_id')  # embedding model queried by get_embeddings()
hf_token = os.getenv('hf_token')  # bearer token for the embedding endpoint
repo_id = os.getenv('repo_id')  # model repo used for the LLM below

# Feature-extraction (embedding) endpoint for `model_id` and its auth header.
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
headers = {"Authorization": f"Bearer {hf_token}"}

# Upper bound, in seconds, for one embedding request. "wait_for_model" can make
# the first call slow, so this is generous. Without a timeout, requests waits
# forever on a stalled connection.
EMBEDDING_REQUEST_TIMEOUT = 120


def get_embeddings(input_str_texts, timeout=EMBEDDING_REQUEST_TIMEOUT):
    """Embed text through the Hugging Face Inference API.

    Args:
        input_str_texts: A single string or a list of strings to embed.
        timeout: Seconds to wait for the HTTP response (``None`` waits forever).

    Returns:
        The decoded JSON body of the API response, i.e. the embedding(s).

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within *timeout* seconds.
    """
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": input_str_texts, "options": {"wait_for_model": True}},
        timeout=timeout,
    )
    # Fail here with the real HTTP error. Otherwise the error payload is returned
    # and only blows up later when it is converted to a tensor.
    response.raise_for_status()
    return response.json()


def generate_random_string(length):
    """Return *length* random lowercase ASCII letters.

    Uses ``random``, so the result is NOT suitable for security purposes.
    """
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for _ in range(length))


def remove_context(text):
    """Strip a leading "Context:" section from an LLM answer.

    If *text* contains the marker ``'Context:'``, everything up to and including
    the first blank line (two consecutive newlines) after that marker is removed.
    Text without the marker is returned unchanged. So is text whose marker is not
    followed by a blank line.
    """
    marker_pos = text.find('Context:')
    if marker_pos == -1:
        return text
    end_of_context = text.find('\n\n', marker_pos)
    if end_of_context == -1:
        # No terminating blank line. Previously find() returned -1 here, and
        # text[-1 + 2:] silently dropped the first character.
        return text
    return text[end_of_context + 2:]


# LLM served through the Hugging Face Hub. Generation is sampled, but
# temperature=0.01 makes it close to greedy.
# NOTE(review): eos_token_id=49155 is tokenizer-specific; confirm it matches the
# model configured by `repo_id`.
llm = HuggingFaceHub(
    repo_id=repo_id,
    model_kwargs={
        "min_length": 512,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.01,
        "top_k": 50,
        "top_p": 0.95,
        "eos_token_id": 49155,
    },
)
# Question-answering chain over `llm`. chain_type="stuff" places every supplied
# document into a single prompt.
chain = load_qa_chain(llm=llm, chain_type="stuff")

app = FastAPI()


class FileToProcess(BaseModel):
    """Multipart form body of the upload endpoint (the PDF to question).

    Declared as a model so the endpoint can receive it through ``Depends()``.
    """

    uploaded_file: UploadFile = File(...)


@app.get("/")
async def home():
    """Liveness check: return a fixed string."""
    return "API Working!"


# Markers at which the generated answer is cut off. Presumably the model tends to
# continue with follow-up questions (Spanish "¿Cuál es" / "¿Cuáles" / "¿Qué es")
# or an end-of-turn token. Confirm against real model outputs.
_ANSWER_CUTOFF_MARKERS = ('¿Cuál es', '¿Cuáles', '¿Qué es', '<|end|>')
# Chat-template tokens deleted from whatever text remains.
_CHAT_TEMPLATE_TOKENS = ('<|end|>', '<|user|>', '<|system|>', '<|assistant|>')
# Trailing sections (presumably model-appended). Everything from the first
# occurrence of each marker onward is dropped.
_TRAILING_SECTION_MARKERS = (
    'Unhelpful Answer:',
    'Note:',
    'Please provide feedback on how to improve the chatbot.',
)


def _extract_pdf_text(pdf_bytes):
    """Return the concatenated text of all pages of a PDF given as raw bytes.

    Pages from which PdfReader extracts no text contribute nothing.
    """
    reader = PdfReader(io.BytesIO(pdf_bytes))
    return ''.join(page.extract_text() or '' for page in reader.pages)


def _split_into_chunks(text):
    """Split *text* with chunk_size=500 and chunk_overlap=100 (lengths via len)."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,  # neighbouring chunks share text across the cut
        length_function=len,
    )
    return splitter.split_text(text)


def _top_matching_chunks(question, chunks, top_k=5):
    """Return the *top_k* chunks whose embeddings are most similar to *question*.

    Both the chunks and the question are embedded with get_embeddings(). The
    ranking comes from sentence_transformers' semantic_search.
    """
    db_embeddings = torch.FloatTensor(get_embeddings(chunks))
    print(db_embeddings)
    print("db_embeddings created...")
    print("API Call Query Received: " + question)
    final_q_embedding = torch.FloatTensor(get_embeddings(question))
    print(final_q_embedding)
    print("Semantic Similarity Search Starts...")
    start_1 = timeit.default_timer()
    hits = semantic_search(final_q_embedding, db_embeddings, top_k=top_k)
    end_1 = timeit.default_timer()
    print("Semantic Similarity Search Ends...")
    print(f'Semantic Similarity Search共耗时: @ {end_1 - start_1}')
    # semantic_search returns one hit list per query; there is exactly one query.
    return [chunks[hit['corpus_id']] for hit in hits[0]]


def _clean_ai_response(raw_response):
    """Post-process the chain's raw answer into the text returned to the client.

    The steps run in this order:
      1. Remove a leading "Context:" section (see remove_context).
      2. Truncate at each cut-off marker.
      3. Collapse double newlines and delete chat-template tokens.
      4. Truncate at each trailing-section marker.
    """
    text = remove_context(raw_response)
    for marker in _ANSWER_CUTOFF_MARKERS:
        text = text.partition(marker)[0].strip()
    text = text.replace('\n\n', '\n')
    for token in _CHAT_TEMPLATE_TOKENS:
        text = text.replace(token, '')
    for marker in _TRAILING_SECTION_MARKERS:
        text = text.split(marker)[0].strip()
    return text


@app.post("/fastapi_file_upload_process")
def pdf_file_qa_process(user_question: str, request: Request, file_to_process: FileToProcess = Depends()):
    """Answer *user_question* using the content of the uploaded PDF.

    Pipeline:
      1. Extract the PDF text and split it into chunks.
      2. Retrieve the 5 chunks most similar to the question.
      3. Run the QA chain on those chunks.
      4. Clean up the generated answer.

    Declared with plain ``def`` rather than ``async def`` because every step is
    blocking (HTTP calls, PDF parsing, LLM inference). FastAPI runs sync
    endpoints in a worker thread instead of stalling the event loop.

    Args:
        user_question: Required ``user_question`` query parameter.
        request: The incoming request. Only its optional ``filename`` query
            parameter is read, for logging.
        file_to_process: Multipart form holding the uploaded PDF.

    Returns:
        ``{"AIResponse": <cleaned answer text>}``.

    Raises:
        HTTPException: 422 if no text can be extracted from the PDF.
    """
    import tempfile

    from fastapi import HTTPException

    uploaded_file = file_to_process.uploaded_file
    print(f"File received:{uploaded_file.filename}")
    filename = request.query_params.get("filename")
    print(f"User entered question: {user_question}")
    # `filename` is optional. The old "..." + filename concatenation raised
    # TypeError when the client omitted it.
    print(f"User uploaded file: {filename}")

    # Parse the PDF from memory. Nothing is written under the working directory,
    # and the client-controlled filename is never used as a filesystem path.
    texts = _split_into_chunks(_extract_pdf_text(uploaded_file.file.read()))
    if not texts:
        raise HTTPException(status_code=422, detail="No extractable text found in the uploaded PDF.")

    question = user_question
    page_contents = _top_matching_chunks(question, texts)
    print(page_contents)
    # str(list) yields the list's repr, where newlines show up as the two
    # characters backslash + "n". Those sequences are stripped.
    final_page_contents = str(page_contents).replace('\\n', '')

    # TextLoader reads from disk, so stage the retrieved text in a temporary
    # directory that is deleted automatically (load() reads eagerly).
    with tempfile.TemporaryDirectory() as work_dir:
        context_path = Path(work_dir) / "retrieved_context.txt"
        context_path.write_text(final_page_contents, encoding="utf-8")
        loaded_documents = TextLoader(str(context_path), encoding="utf-8").load()
    print("*****loaded_documents******")
    print(loaded_documents)
    print("***********")
    print(question)
    print("*****question******")

    print("LLM Chain Starts...")
    start_2 = timeit.default_timer()
    temp_ai_response = chain({"input_documents": loaded_documents, "question": question}, return_only_outputs=False)
    end_2 = timeit.default_timer()
    print("LLM Chain Ends...")
    print(f'LLM Chain共耗时: @ {end_2 - start_2}')
    print(temp_ai_response)

    initial_ai_response = temp_ai_response['output_text']
    print(initial_ai_response)
    new_final_ai_response = _clean_ai_response(initial_ai_response)
    print(new_final_ai_response)
    return {"AIResponse": new_final_ai_response}