Update main.py
main.py
CHANGED
@@ -8,6 +8,13 @@ from PyPDF2 import PdfReader
 from fastapi import Depends
 # In FastAPI, Depends() is used to declare dependencies
 
+from langchain.chains.question_answering import load_qa_chain
+from langchain import PromptTemplate, LLMChain
+from langchain import HuggingFaceHub
+from langchain.document_loaders import TextLoader
+import torch
+
+import requests
 import random
 import string
 import sys
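Note: the comment retained at the top of this hunk explains that FastAPI's Depends() declares a dependency that FastAPI resolves and injects on each request. A minimal sketch of that mechanism (the get_api_key dependency is hypothetical, not part of this Space):

from fastapi import Depends, FastAPI

app = FastAPI()

def get_api_key() -> str:
    # Hypothetical dependency: FastAPI calls this on each request
    # and injects the return value into the endpoint parameter.
    return "demo-key"

@app.get("/health")
async def health(api_key: str = Depends(get_api_key)):
    return {"api_key": api_key}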
@@ -19,10 +26,44 @@ import os
 from dotenv import load_dotenv
 load_dotenv()
 
+HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+model_id = os.getenv('model_id')
+hf_token = os.getenv('hf_token')
+repo_id = os.getenv('repo_id')
+
+def get_embeddings(input_str_texts):
+    response = requests.post(api_url, headers=headers, json={"inputs": input_str_texts, "options": {"wait_for_model": True}})
+    return response.json()
+
 def generate_random_string(length):
     letters = string.ascii_lowercase
     return ''.join(random.choice(letters) for i in range(length))
 
+def remove_context(text):
+    if 'Context:' in text:
+        end_of_context = text.find('\n\n')
+        return text[end_of_context + 2:]
+    else:
+        return text
+
+api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
+headers = {"Authorization": f"Bearer {hf_token}"}
+
+llm = HuggingFaceHub(repo_id=repo_id,
+                     model_kwargs={"min_length": 100,
+                                   "max_new_tokens": 1024, "do_sample": True,
+                                   "temperature": 0.1,
+                                   "top_k": 50,
+                                   "top_p": 0.95, "eos_token_id": 49155})
+
+prompt_template = """
+You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question {question}. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
+Your response should be full and easy to understand.
+"""
+PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
+
+chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
+
 app = FastAPI()
 
 class FileToProcess(BaseModel):
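Note: the new get_embeddings() helper posts text chunks to the Hugging Face Inference API's feature-extraction pipeline and returns the raw JSON, which the handler later wraps in torch.FloatTensor. A self-contained sketch of the same call; the model name and HF_TOKEN variable are placeholders here, since the Space reads model_id and hf_token from its environment:

import os
import requests
import torch

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # assumed embedding model
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

resp = requests.post(api_url, headers=headers,
                     json={"inputs": ["first chunk", "second chunk"],
                           "options": {"wait_for_model": True}})
embeddings = torch.FloatTensor(resp.json())
print(embeddings.shape)  # (2, embedding_dim) for two input chunks

The "wait_for_model" option blocks until the model is loaded on the inference backend instead of failing with a 503 while it warms up. Separately, the eos_token_id of 49155 and the <|end|>/<|user|> markers stripped in the next hunk suggest repo_id points at a StarChat-style chat model, though the actual value comes from the environment.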
@@ -66,10 +107,48 @@ async def pdf_file_qa_process(username: str, request: Request, file_to_process:
             text = page.extract_text()
             if text:
                 raw_text += text
-    temp_texts = text_splitter.split_text(raw_text)
+    temp_texts = text_splitter.split_text(raw_text)
+    texts = temp_texts
+    initial_embeddings = get_embeddings(temp_texts)
+    db_embeddings = torch.FloatTensor(initial_embeddings)
+    print("db_embeddings created...")
+
+    #question = var_query.query
+    question = username
+    print("API Call Query Received: " + question)
+    q_embedding = get_embeddings(question)
+    final_q_embedding = torch.FloatTensor(q_embedding)
+    from sentence_transformers.util import semantic_search
+    hits = semantic_search(final_q_embedding, torch.FloatTensor(db_embeddings), top_k=5)
+
+    page_contents = []
+    for i in range(len(hits[0])):
+        page_content = texts[hits[0][i]['corpus_id']]
+        page_contents.append(page_content)
 
-
-
+    temp_page_contents = str(page_contents)
+    final_page_contents = temp_page_contents.replace('\\n', '')
+    random_string_2 = generate_random_string(20)
+    file_path = random_string_2 + ".txt"
+    with open(file_path, "w", encoding="utf-8") as file:
+        file.write(final_page_contents)
+
+    loader = TextLoader(file_path, encoding="utf-8")
+    loaded_documents = loader.load()
+
+    temp_ai_response = chain({"input_documents": loaded_documents, "question": question}, return_only_outputs=False)
+
+    initial_ai_response = temp_ai_response['output_text']
+
+    cleaned_initial_ai_response = remove_context(initial_ai_response)
+
+    #final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
+    final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip()
+    final_ai_response = final_ai_response.partition('¿Cuáles')[0].strip()
+    final_ai_response = final_ai_response.partition('<|end|>')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
+    new_final_ai_response = final_ai_response.split('Unhelpful Answer:')[0].strip()
+    new_final_ai_response = new_final_ai_response.split('Note:')[0].strip()
+    new_final_ai_response = new_final_ai_response.split('Please provide feedback on how to improve the chatbot.')[0].strip()
 
     api_call_msg = {"INFO": f"File '{file_saved_in_api}' saved to your profile."}
     print(api_call_msg)
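Note: semantic_search() from sentence-transformers returns one result list per query row, each entry a dict with 'corpus_id' and 'score' sorted by cosine similarity; that is why the loop maps hits[0][i]['corpus_id'] back into texts. A toy, self-contained example:

import torch
from sentence_transformers.util import semantic_search

texts = ["chunk about pandas", "chunk about FastAPI", "chunk about PDFs"]
db_embeddings = torch.FloatTensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
q_embedding = torch.FloatTensor([[0.1, 0.9]])  # one query row

hits = semantic_search(q_embedding, db_embeddings, top_k=2)
# hits == [[{'corpus_id': 1, 'score': ...}, {'corpus_id': 2, 'score': ...}]]
for hit in hits[0]:
    print(texts[hit['corpus_id']], round(hit['score'], 3))

Writing the top-k chunks to a randomly named .txt file and re-reading them with TextLoader is a roundabout way to build the Document objects load_qa_chain expects; constructing langchain.schema.Document in memory would avoid the disk round trip, though the file-based route is what this commit ships.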
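Note: the post-processing chain at the end of the last hunk trims everything the model emits after its first answer: remove_context() drops an echoed "Context:" block, and the partition/split calls cut at chat special tokens and trailing boilerplate such as "Unhelpful Answer:". The same idea, condensed into one helper (clean_response is hypothetical, not in the commit):

def remove_context(text: str) -> str:
    # Drop an echoed "Context: ..." block up to the first blank line.
    if 'Context:' in text:
        end_of_context = text.find('\n\n')
        return text[end_of_context + 2:]
    return text

def clean_response(raw: str) -> str:
    # Hypothetical consolidation of the commit's partition/split chain.
    text = remove_context(raw)
    for marker in ('¿Cuál es', '¿Cuáles', '<|end|>', 'Unhelpful Answer:', 'Note:',
                   'Please provide feedback on how to improve the chatbot.'):
        text = text.split(marker)[0]
    for token in ('<|user|>', '<|system|>', '<|assistant|>'):
        text = text.replace(token, '')
    return text.strip().replace('\n\n', '\n')

print(clean_response("Context: ...\n\nParis is the capital.<|end|>\nNote: ..."))
# -> "Paris is the capital."

Within the lines shown, new_final_ai_response is not yet consumed; any use of the cleaned answer would sit below this hunk's context window.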