# FinDoc / app.py
# -*- coding: utf-8 -*-
import gradio as gr
import os
import json
from glob import glob
import requests
from langchain.vectorstores import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain.chains import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from build_index.run import process_files
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains import QAGenerationChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain

# Streaming endpoint (kept for reference; requests go through the langchain wrappers below)
API_URL = "https://api.openai.com/v1/chat/completions"
# Read the Cohere key from the environment rather than hardcoding a secret in source
cohere_key = os.environ.get('COHERE_API_KEY', '')
faiss_store = './output/'
# Module-level FAISS index, loaded lazily on the first chat request
docsearch = None
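

# Ingestion step: process_files() splits and embeds the uploaded PDFs (the glob of
# ./output/ below suggests it persists the FAISS index there), then questions and a
# summary are drafted from the leading chunks.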
def process(files, openai_api_key, max_tokens, n_sample):
    """
    Summarize the documents, generate sample questions, and build the document index.
    """
    os.environ['OPENAI_API_KEY'] = openai_api_key
    print('Displaying uploaded files')
    print(glob('/tmp/*'))
    docs = process_files([i.name for i in files], 'openai', max_tokens)
    print('Displaying FAISS index')
    print(glob('./output/*'))
    question = get_question(docs, openai_api_key, max_tokens, n_sample)
    summary = get_summary(docs, openai_api_key, max_tokens, n_sample)
    return question, summary
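

# Question generation: run a QAGenerationChain over the first n_sample chunks and
# surface one generated question per chunk as a suggested prompt for the user.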
def get_question(docs, openai_api_key, max_tokens, n_sample=5):
    q_list = []
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens, temperature=0)
    # Generate QA pairs grounded in the document chunks
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(MyTemplate['qa_sys_template']),
            HumanMessagePromptTemplate.from_template(MyTemplate['qa_user_template']),
        ]
    )
    chain = QAGenerationChain.from_llm(llm, prompt=prompt)
    print('Generating questions from template')
    for i in range(n_sample):
        # chain.run returns a list of {'question': ..., 'answer': ...} dicts; keep the first
        qa = chain.run(docs[i].page_content)[0]
        print(qa)
        q_list.append(f"问题{i + 1}: {qa['question']}")
    return '\n'.join(q_list)
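

# Summarization: a map-reduce pass in which each chunk is summarized independently
# (map) and the partial summaries are stuffed into a single combine call (reduce).
# Note that the same prompt template is reused for both the map and combine steps.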
def get_summary(docs, openai_api_key, max_tokens, n_sample=5, verbose=None):
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    print('Generating Summary from template')
    map_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    combine_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
    reduce_chain = LLMChain(llm=llm, prompt=combine_prompt, verbose=verbose)
    combine_document_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name='text',
        verbose=verbose,
    )
    chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        document_variable_name='text',
        collapse_document_chain=None,
        verbose=verbose,
    )
    summary = chain.run(docs[:n_sample])
    print(summary)
    return summary
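

# Chat step: answer a user question against the persisted FAISS index with a
# map-reduce retrieval-QA chain over the top-4 chunks, yielding the updated chat
# state back to the UI.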
def predict(inputs, openai_api_key, max_tokens, chat_counter, chatbot=[], history=[]):
    global docsearch
    print(f"chat_counter - {chat_counter}")
    print(f'History - {history}')  # History: original inputs and outputs in a flat list
    print(f'chatbot - {chatbot}')  # Chatbot: previous turns as [[user, AI]] pairs
    history.append(inputs)
    if docsearch is None:
        print(f'loading faiss store from {faiss_store}')
        docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        print('faiss already loaded')
    # Build the chat prompts
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    messages_combine = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
    messages_reduce = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
    chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
                                       k=4,
                                       chain_type_kwargs={"question_prompt": p_chat_reduce,
                                                          "combine_prompt": p_chat_combine})
    result = chain({"query": inputs})
    print(result)
    result = result['result']
    # Build the return values
    history.append(result)
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    chat_counter += 1
    yield chat, history, chat_counter


def reset_textbox():
    return gr.update(value='')
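

# Gradio UI: an API-key field and generation parameters on top, a PDF-upload pane
# feeding the summary and suggested-question boxes, and a chatbot backed by predict().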
with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                      #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Smart Doc Reader🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="输入 API Key")
        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                max_tokens = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, interactive=True,
                                       label="字数")
                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
                n_sample = gr.Slider(minimum=3, maximum=5, value=3, step=1, interactive=True,
                                     label="问题数")
        # Upload files, then run summarization and question generation
        with gr.Row():
            with gr.Column():
                files = gr.File(file_count="multiple", file_types=[".pdf"], label='上传pdf文件')
                run = gr.Button('文档内容解读')
            with gr.Column():
                summary = gr.Textbox(type='text', label="一眼看尽 - 文档概览")
                question = gr.Textbox(type='text', label='推荐问题 - 问别的也行哟')
        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="这篇文档是关于什么的", label="针对文档你有哪些问题?")
        state = gr.State([])
        with gr.Row():
            clear = gr.Button("清空")
            start = gr.Button("提问")

    # Wire events: document processing, question answering (Enter key or button), and reset
    run.click(process, [files, openai_api_key, max_tokens, n_sample], [question, summary])
    inputs.submit(predict,
                  [inputs, openai_api_key, max_tokens, chat_counter, chatbot, state],
                  [chatbot, state, chat_counter])
    start.click(predict,
                [inputs, openai_api_key, max_tokens, chat_counter, chatbot, state],
                [chatbot, state, chat_counter])
    # Clear the input box at the end of each turn
    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])

demo.queue().launch(debug=True)
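
# A minimal local-run sketch (assumed environment, not part of the original Space):
#   pip install gradio langchain openai faiss-cpu
#   export COHERE_API_KEY=...   # optional; the OpenAI key is entered in the UI
#   python app.py               # then open the printed local URL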