iShare committed on
Commit
eb1095f
·
1 Parent(s): 9a18337

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +82 -3
main.py CHANGED
@@ -8,6 +8,13 @@ from PyPDF2 import PdfReader
8
  from fastapi import Depends
9
  #在FastAPI中,Depends()函数用于声明依赖项
10
 
 
 
 
 
 
 
 
11
  import random
12
  import string
13
  import sys
@@ -19,10 +26,44 @@ import os
19
  from dotenv import load_dotenv
20
  load_dotenv()
21
 
 
 
 
 
 
 
 
 
 
22
  def generate_random_string(length):
23
  letters = string.ascii_lowercase
24
  return ''.join(random.choice(letters) for i in range(length))
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  app = FastAPI()
27
 
28
  class FileToProcess(BaseModel):
@@ -66,10 +107,48 @@ async def pdf_file_qa_process(username: str, request: Request, file_to_process:
66
  text = page.extract_text()
67
  if text:
68
  raw_text += text
69
- temp_texts = text_splitter.split_text(raw_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- print(temp_texts)
72
- print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  api_call_msg={"INFO": f"File '{file_saved_in_api}' saved to your profile."}
75
  print(api_call_msg)
 
8
  from fastapi import Depends
9
  #在FastAPI中,Depends()函数用于声明依赖项
10
 
11
+ from langchain.chains.question_answering import load_qa_chain
12
+ from langchain import PromptTemplate, LLMChain
13
+ from langchain import HuggingFaceHub
14
+ from langchain.document_loaders import TextLoader
15
+ import torch
16
+
17
+ import requests
18
  import random
19
  import string
20
  import sys
 
26
  from dotenv import load_dotenv
27
  load_dotenv()
28
 
29
+ HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
30
+ model_id = os.getenv('model_id')
31
+ hf_token = os.getenv('hf_token')
32
+ repo_id = os.getenv('repo_id')
33
+
34
def get_embeddings(input_str_texts):
    """Request embeddings for the given text(s) from the inference endpoint.

    Posts ``input_str_texts`` as the ``inputs`` field of a JSON payload to the
    module-level ``api_url`` using the module-level ``headers``, with the
    ``wait_for_model`` option enabled.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if no response arrives within the time limit.
    """
    payload = {"inputs": input_str_texts, "options": {"wait_for_model": True}}
    # Without a timeout, requests can wait forever on a stalled connection.
    # The limit is generous because wait_for_model presumably blocks while the
    # remote model is loading -- TODO confirm a suitable value.
    response = requests.post(api_url, headers=headers, json=payload, timeout=120)
    # Fail here on an HTTP error instead of handing an error payload to callers
    # that expect embedding data.
    response.raise_for_status()
    return response.json()
37
+
38
  def generate_random_string(length):
39
  letters = string.ascii_lowercase
40
  return ''.join(random.choice(letters) for i in range(length))
41
 
42
def remove_context(text):
    """Strip the leading context section from a model response.

    When ``text`` contains the marker ``'Context:'``, everything up to and
    including the first blank line (two consecutive newlines) is dropped and
    the remainder is returned. Text without the marker is returned unchanged.

    If the marker is present but the text has no blank line, the text is
    returned unchanged. A plain ``str.find`` yields -1 in that case, and
    slicing from ``-1 + 2`` would silently cut off the first character.
    """
    if 'Context:' not in text:
        return text
    _, separator, remainder = text.partition('\n\n')
    # An empty separator means no blank line was found, so nothing is cut.
    return remainder if separator else text
48
+
49
+ api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
50
+ headers = {"Authorization": f"Bearer {hf_token}"}
51
+
52
+ llm = HuggingFaceHub(repo_id=repo_id,
53
+ model_kwargs={"min_length":100,
54
+ "max_new_tokens":1024, "do_sample":True,
55
+ "temperature":0.1,
56
+ "top_k":50,
57
+ "top_p":0.95, "eos_token_id":49155})
58
+
59
+ prompt_template = """
60
+ You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question {question}. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
61
+ Your response should be full and easy to understand.
62
+ """
63
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
64
+
65
+ chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)
66
+
67
  app = FastAPI()
68
 
69
  class FileToProcess(BaseModel):
 
107
  text = page.extract_text()
108
  if text:
109
  raw_text += text
110
+ temp_texts = text_splitter.split_text(raw_text)
111
+ texts=temp_texts
112
+ initial_embeddings=get_embeddings(temp_texts)
113
+ db_embeddings = torch.FloatTensor(initial_embeddings)
114
+ print("db_embeddings created...")
115
+
116
+ #question = var_query.query
117
+ question = username
118
+ print("API Call Query Received: "+question)
119
+ q_embedding=get_embeddings(question)
120
+ final_q_embedding = torch.FloatTensor(q_embedding)
121
+ from sentence_transformers.util import semantic_search
122
+ hits = semantic_search(final_q_embedding, torch.FloatTensor(db_embeddings), top_k=5)
123
+
124
+ page_contents = []
125
+ for i in range(len(hits[0])):
126
+ page_content = texts[hits[0][i]['corpus_id']]
127
+ page_contents.append(page_content)
128
 
129
+ temp_page_contents=str(page_contents)
130
+ final_page_contents = temp_page_contents.replace('\\n', '')
131
+ random_string_2=generate_random_string(20)
132
+ file_path = random_string_2 + ".txt"
133
+ with open(file_path, "w", encoding="utf-8") as file:
134
+ file.write(final_page_contents)
135
+
136
+ loader = TextLoader(file_path, encoding="utf-8")
137
+ loaded_documents = loader.load()
138
+
139
+ temp_ai_response = chain({"input_documents": loaded_documents, "question": question}, return_only_outputs=False)
140
+
141
+ initial_ai_response=temp_ai_response['output_text']
142
+
143
+ cleaned_initial_ai_response = remove_context(initial_ai_response)
144
+
145
+ #final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
146
+ final_ai_response = cleaned_initial_ai_response.partition('¿Cuál es')[0].strip()
147
+ final_ai_response = final_ai_response.partition('¿Cuáles')[0].strip()
148
+ final_ai_response = final_ai_response.partition('<|end|>')[0].strip().replace('\n\n', '\n').replace('<|end|>', '').replace('<|user|>', '').replace('<|system|>', '').replace('<|assistant|>', '')
149
+ new_final_ai_response = final_ai_response.split('Unhelpful Answer:')[0].strip()
150
+ new_final_ai_response = new_final_ai_response.split('Note:')[0].strip()
151
+ new_final_ai_response = new_final_ai_response.split('Please provide feedback on how to improve the chatbot.')[0].strip()
152
 
153
  api_call_msg={"INFO": f"File '{file_saved_in_api}' saved to your profile."}
154
  print(api_call_msg)