import os
import json
from glob import glob

import requests
import gradio as gr
from langchain.vectorstores import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain.chains import VectorDBQA, QAGenerationChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain

from prompts import MyTemplate
from build_index.run import process_files
|
API_URL = "https://api.openai.com/v1/chat/completions"
# Read the Cohere key from the environment rather than shipping a hardcoded secret.
cohere_key = os.environ.get('COHERE_API_KEY', '')
faiss_store = './output/'
# Module-level cache for the FAISS index; populated on first use in predict().
doc_search = None
|
|
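# Callback for the "文档内容解读" button: index the uploaded PDFs, then
# generate suggested questions and an overall summary for display.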
def process(files, openai_api_key, max_tokens, n_sample):
    """
    Summarize the documents, generate suggested questions, and build the document index.
    """
    os.environ['OPENAI_API_KEY'] = openai_api_key
    print('Displaying uploaded files')
    print(glob('/tmp/*'))
    docs = process_files([i.name for i in files], 'openai', max_tokens)
    print('Displaying FAISS index files')
    print(glob('./output/*'))
    question = get_question(docs, openai_api_key, max_tokens, n_sample)
    summary = get_summary(docs, openai_api_key, max_tokens, n_sample)
    return question, summary
|
|
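# Draft one suggested question per sampled chunk. QAGenerationChain parses the
# LLM output into a list of {'question': ..., 'answer': ...} dicts; only the
# question text is surfaced to the user.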
def get_question(docs, openai_api_key, max_tokens, n_sample=5):
    q_list = []
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens, temperature=0)

    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(MyTemplate['qa_sys_template']),
            HumanMessagePromptTemplate.from_template(MyTemplate['qa_user_template']),
        ]
    )
    chain = QAGenerationChain.from_llm(llm, prompt=prompt)
    print('Generating questions from template')
    # Guard against uploads that yield fewer chunks than n_sample.
    for i in range(min(n_sample, len(docs))):
        qa = chain.run(docs[i].page_content)[0]
        print(qa)
        q_list.append(f"问题{i + 1}: {qa['question']}")
    return '\n'.join(q_list)
|
|
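# Map-reduce summarization: summarize each chunk independently (map), then
# stuff the partial summaries into a single prompt and condense them (reduce).
# Note that the same summary template is reused for both stages.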
def get_summary(docs, openai_api_key, max_tokens, n_sample=5, verbose=None):
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)

    print('Generating summary from template')
    map_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    combine_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
    reduce_chain = LLMChain(llm=llm, prompt=combine_prompt, verbose=verbose)
    combine_document_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name='text',
        verbose=verbose,
    )
    chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        document_variable_name='text',
        collapse_document_chain=None,
        verbose=verbose,
    )
    summary = chain.run(docs[:n_sample])
    print(summary)
    return summary
|
|
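# Chat handler: lazily load the FAISS index built by process(), then answer the
# query with a map_reduce VectorDBQA chain over the top-k retrieved chunks.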
def predict(inputs, openai_api_key, max_tokens, chat_counter, chatbot=None, history=None):
    # Use None defaults to avoid shared mutable default arguments.
    chatbot = chatbot if chatbot is not None else []
    history = history if history is not None else []
    global doc_search
    print(f"chat_counter - {chat_counter}")
    print(f'History - {history}')
    print(f'chatbot - {chatbot}')

    history.append(inputs)
    if doc_search is None:
        print(f'loading faiss store from {faiss_store}')
        doc_search = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        print('faiss already loaded')

    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    messages_combine = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
    messages_reduce = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
    chain = VectorDBQA.from_chain_type(
        llm=llm, chain_type="map_reduce", vectorstore=doc_search, k=4,
        chain_type_kwargs={"question_prompt": p_chat_reduce,
                           "combine_prompt": p_chat_combine},
    )
    result = chain({"query": inputs})
    print(result)
    result = result['result']

    history.append(result)
    # Pair consecutive (user, bot) turns for the Chatbot component.
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    chat_counter += 1
    yield chat, history, chat_counter
|
|
def reset_textbox():
    return gr.update(value='')
|
|
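# Gradio UI: an API-key box and generation parameters up top, PDF upload with
# summary/suggested-question panels in the middle, and a chatbot for follow-ups.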
with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                      #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Smart Doc Reader🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="输入 API Key")

        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                max_tokens = gr.Slider(minimum=100, maximum=2000, value=1000, step=100,
                                       interactive=True, label="字数")
                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
                n_sample = gr.Slider(minimum=3, maximum=5, value=3, step=1,
                                     interactive=True, label="问题数")

        with gr.Row():
            with gr.Column():
                files = gr.File(file_count="multiple", file_types=[".pdf"], label='上传pdf文件')
                run = gr.Button('文档内容解读')

            with gr.Column():
                summary = gr.Textbox(type='text', label="一眼看尽 - 文档概览")
                question = gr.Textbox(type='text', label='推荐问题 - 问别的也行哟')

        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="这篇文档是关于什么的", label="针对文档你有哪些问题?")
        state = gr.State([])

        with gr.Row():
            clear = gr.Button("清空")
            start = gr.Button("提问")
|
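    # Wire up events: the run button indexes the files and fills the summary and
    # question boxes; submitting a question (or clicking 提问) routes it through
    # predict, and submit also clears the input box afterwards.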
    run.click(process, [files, openai_api_key, max_tokens, n_sample], [question, summary])
    inputs.submit(predict,
                  [inputs, openai_api_key, max_tokens, chat_counter, chatbot, state],
                  [chatbot, state, chat_counter])
    start.click(predict,
                [inputs, openai_api_key, max_tokens, chat_counter, chatbot, state],
                [chatbot, state, chat_counter])

    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])

demo.queue().launch(debug=True)
|
|