initial
- app.py +153 -0
- build_index/__init__.py +1 -0
- build_index/base.py +88 -0
- build_index/doc2vec.py +52 -0
- build_index/parser/__init__.py +7 -0
- build_index/parser/base.py +39 -0
- build_index/parser/html_parser.py +48 -0
- build_index/parser/pdf_parser.py +27 -0
- build_index/pricing.py +23 -0
- build_index/process.py +78 -0
- build_index/run.py +56 -0
- build_index/unit_test/__init__.py +1 -0
- build_index/unit_test/test_faiss.py +33 -0
- build_index/unit_test/test_loader.py +1 -0
- key.py +4 -0
- prompts/__init__.py +22 -0
- prompts/chat_combine_prompt.txt +4 -0
- prompts/chat_reduce_prompt.txt +3 -0
- prompts/combine_prompt.txt +25 -0
- prompts/combine_prompt_hist.txt +33 -0
- prompts/question_prompt.txt +4 -0
- prompts/style.py +14 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,153 @@
# -*-coding:utf-8 -*-
import gradio as gr
import os
import json
from glob import glob
import requests
from langchain import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from build_index.run import process_files
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

from langchain.chains.summarize import load_summarize_chain
from langchain.chains import QAGenerationChain

# Streaming endpoint
API_URL = "https://api.openai.com/v1/chat/completions"
cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
faiss_store = './output/'


def process(files, model, max_tokens, openai_api_key, n_sample):
    """
    Summarize the uploaded documents, generate candidate questions, and build the document index
    """
    model = model[0]
    os.environ['OPENAI_API_KEY'] = openai_api_key
    print('Displaying uploaded files')
    print(glob('/tmp/*'))
    docs = process_files([i.name for i in files], model, max_tokens)
    print('Display FAISS index')
    print(glob('./output/*'))
    question = get_question(docs, openai_api_key, max_tokens, n_sample)
    summary = get_summary(docs, openai_api_key, max_tokens, n_sample)
    return question, summary


def get_question(docs, openai_api_key, max_tokens, n_sample=5):
    q_list = []
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens, temperature=0)
    chain = QAGenerationChain.from_llm(llm)
    print('Generating Question from template')
    for i in range(n_sample):
        qa = chain.run(docs[i].page_content)[0]
        print(qa)
        q_list.append(f"问题{i+1}: {qa['question']}")
    return '\n'.join(q_list)


def get_summary(docs, openai_api_key, max_tokens, n_sample=5):
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    chain = load_summarize_chain(llm, chain_type="map_reduce")
    print('Generating Summary from template')
    summary = chain.run(docs[:n_sample])
    print(summary)
    return summary


def predict(inputs, openai_api_key, max_tokens, model, chat_counter, chatbot=[], history=[]):
    model = model[0]
    print(f"chat_counter - {chat_counter}")
    print(f'History - {history}')  # History: original inputs and outputs in a flattened list
    print(f'chatbot - {chatbot}')  # Chatbot: [[user, AI]] pairs from the previous turn

    history.append(inputs)
    print(f'loading faiss store from {faiss_store}')
    if model == 'openai':
        docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=cohere_key))
    # Build the prompt templates
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    messages_combine = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
    messages_reduce = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
    chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
                                       k=4,
                                       chain_type_kwargs={"question_prompt": p_chat_reduce,
                                                          "combine_prompt": p_chat_combine}
                                       )
    result = chain({"query": inputs})
    print(result)
    result = result['result']
    # Build the return values
    history.append(result)
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    chat_counter += 1
    yield chat, history, chat_counter


def reset_textbox():
    return gr.update(value='')


with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                      #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Your Doc Reader🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="输入 API Key")

        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                max_tokens = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, interactive=True,
                                       label="字数")
                model = gr.CheckboxGroup(["cohere", "openai"])
                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
                n_sample = gr.Slider(minimum=3, maximum=5, value=3, step=1, interactive=True,
                                     label="问题数")

        # Upload files, then run summarization and question generation
        with gr.Row():
            with gr.Column():
                files = gr.File(file_count="multiple", file_types=[".pdf"], label='上传pdf文件')
                run = gr.Button('研报解读')

            with gr.Column():
                summary = gr.Textbox(type='text', label="本文摘要")
                question = gr.Textbox(type='text', label='提问问题')

        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="这篇文档是关于什么的", label="针对文档你有哪些问题?")
        state = gr.State([])

        with gr.Row():
            clear = gr.Button("清空")
            start = gr.Button("提问")

    # inputs are passed in the order expected by process(files, model, max_tokens, openai_api_key, n_sample)
    run.click(process, [files, model, max_tokens, openai_api_key, n_sample], [question, summary])
    inputs.submit(predict,
                  [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                  [chatbot, state, chat_counter], )
    start.click(predict,
                [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                [chatbot, state, chat_counter], )

    # Reset the input box after each turn
    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])

demo.queue().launch(debug=True)
build_index/__init__.py
ADDED
@@ -0,0 +1 @@
# -*-coding:utf-8 -*-
build_index/base.py
ADDED
@@ -0,0 +1,88 @@
# -*-coding:utf-8 -*-
"""
Base Reader and Document
"""
import os
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from typing import Any, Dict, List, Optional
from glob import glob
from build_index.parser import ParserFactory
from langchain.docstore.document import Document as LCDocument


@dataclass_json
@dataclass
class Document:
    text: str = None
    doc_id: Optional[str] = None
    embedding: Optional[List[float]] = None
    extra_info: Optional[Dict[str, Any]] = None

    def get_text(self):
        return self.text

    def get_doc_id(self):
        return self.doc_id

    def get_embedding(self):
        return self.embedding

    @property
    def extra_info_str(self) -> Optional[str]:
        """Extra info string."""
        if self.extra_info is None:
            return None

        return "\n".join([f"{k}: {str(v)}" for k, v in self.extra_info.items()])

    def __post_init__(self):
        # Field validation
        assert self.text is not None, 'Text Field can not be None'

    def to_langchain_format(self):
        """Convert struct to LangChain document format."""
        metadata = self.extra_info or {}
        return LCDocument(page_content=self.text, metadata=metadata)


class FileReader(object):
    """
    Load files from ./data_dir
    """
    def __init__(self, data_dir=None, folder_name=None, input_files=None, has_meta=True):
        self.data_dir = data_dir
        self.has_meta = has_meta

        if input_files:
            self.input_files = input_files
        else:
            # Get all files in data_dir
            # TODO: recursive directories under data are not supported yet
            dir = os.path.join(data_dir, folder_name, '*')
            self.input_files = glob(dir)
            print(f'{len(self.input_files)} files in {dir}')
            print(self.input_files)

    def load_data(self, concatenate=False) -> List[Document]:
        data_list = []
        metadata_list = []
        for file in self.input_files:
            parser = ParserFactory['pdf']
            if parser is None:
                raise ValueError(f"{file} format doesn't match any suffix supported")
            try:
                data, meta = parser.parse_file(file)
            except Exception as e:
                print(f'{file} parse failed. error = {e}')
                continue
            data_list.append(data)
            if self.has_meta:
                metadata_list.append(meta)

        if concatenate:
            return [Document("\n".join(data_list))]
        elif self.has_meta:
            return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
        else:
            return [Document(d) for d in data_list]
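For reference, a minimal usage sketch of FileReader, mirroring what build_index/run.py does; the file path is a placeholder, not part of the repository:

# Sketch: read a local PDF and convert it to LangChain documents.
from build_index.base import FileReader, Document

reader = FileReader(input_files=['./data/两会/report.pdf'])   # placeholder path, replace with a real PDF
docs = reader.load_data()                                     # Document objects with extra_info={'source': ..., 'pages': ...}
lc_docs = [Document.to_langchain_format(d) for d in docs]
print(len(lc_docs), [d.metadata for d in lc_docs])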
build_index/doc2vec.py
ADDED
@@ -0,0 +1,52 @@
# -*-coding:utf-8 -*-
import os
from tqdm import tqdm
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, CohereEmbeddings
from retry import retry
from key import CoherenceKey, OpenaiKey

# Output Directory for FAISS Index data
OUTPUT_DIR = './output/'


@retry(tries=10, delay=60)
def store_add_texts_with_retry(store, i):
    store.add_texts([i.page_content], metadatas=[i.metadata])


def doc2vec(docs, model, folder_name=None):
    if folder_name:
        dir = os.path.join(OUTPUT_DIR, folder_name)
    else:
        dir = OUTPUT_DIR
    # Use the first document to init the db, then add documents one at a time so a
    # failure in the middle does not force restarting from scratch
    print(f'Building faiss Index from {len(docs)} docs')
    docs_test = [docs[0]]
    docs.pop(0)
    index = 0

    print(f'Dumping FAISS to {dir}')
    if model == 'openai':
        key = os.getenv('OPENAI_API_KEY')
        db = FAISS.from_documents(docs_test, OpenAIEmbeddings(openai_api_key=key))
    elif model == 'mpnet':
        db = FAISS.from_documents(docs_test, HuggingFaceEmbeddings())
    elif model == 'cohere':
        db = FAISS.from_documents(docs_test, CohereEmbeddings(cohere_api_key=CoherenceKey))
    else:
        raise ValueError(f'Embedding Model {model} not supported')

    for doc in tqdm(docs, desc="Embedding 🦖", unit="docs", total=len(docs),
                    bar_format='{l_bar}{bar}| Time Left: {remaining}'):
        try:
            store_add_texts_with_retry(db, doc)
        except Exception as e:
            print(e)
            print("Error on ", doc)
            print("Saving progress")
            print(f"stopped at {index} out of {len(docs)}")
            db.save_local(dir)
            break
        index += 1
    db.save_local(dir)
build_index/parser/__init__.py
ADDED
@@ -0,0 +1,7 @@
# -*-coding:utf-8 -*-
from .html_parser import *
from .pdf_parser import *

ParserFactory = {
    'html': HTMLParser(),
    'pdf': PDFParser()
}
build_index/parser/base.py
ADDED
@@ -0,0 +1,39 @@
# -*-coding:utf-8 -*-
"""
Parsers for documents in different formats
"""
import re
import abc
from typing import Union, List


class BaseParser(abc.ABC):

    def __init__(self, config=None):
        self.config = config

    @staticmethod
    def full2half(text):
        # Convert full-width characters to their half-width equivalents
        s = ''
        for c in text:
            num = ord(c)
            if num == 0x3000:
                num = 0x20
            elif 0xFF01 <= num <= 0xFF5E:
                num = num - 0xFEE0
            s += chr(num)
        return s

    @staticmethod
    def remove_dup_space(text):
        text = re.sub(r'\s{1,}', '', text, flags=re.MULTILINE | re.DOTALL)  # Remove extra whitespace
        return text

    @staticmethod
    def remove_empty_line(text):
        text = re.sub(r'\n', '', text, flags=re.MULTILINE | re.DOTALL)  # Remove newlines / empty lines
        return text

    def parse_file(self, file):
        raise NotImplementedError
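For reference, a minimal sketch of the static helpers above; they can be exercised directly on the class without a concrete parser:

# Sketch: normalize full-width characters and strip whitespace.
from build_index.parser.base import BaseParser

print(BaseParser.full2half('ABC!123'))       # -> 'ABC!123' (full-width chars mapped to ASCII)
print(BaseParser.remove_dup_space('a  b\t c'))    # -> 'abc' (all whitespace runs removed)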
build_index/parser/html_parser.py
ADDED
@@ -0,0 +1,48 @@
# -*-coding:utf-8 -*-

import re
from unstructured.partition.html import partition_html
from unstructured.staging.base import convert_to_isd
from unstructured.cleaners.core import clean
from build_index.parser.base import BaseParser


class HTMLParser(BaseParser):
    def parse_file(self, file):
        with open(file, "r", encoding="utf-8") as fp:
            elements = partition_html(file=fp)
            isd = convert_to_isd(elements)

        for isd_el in isd:
            isd_el['text'] = isd_el['text'].encode("ascii", "ignore").decode()
            isd_el['text'] = self.remove_dup_space(isd_el['text'])
            isd_el['text'] = self.remove_empty_line(isd_el['text'])
            isd_el['text'] = clean(isd_el['text'], extra_whitespace=True, dashes=True, bullets=True,
                                   trailing_punctuation=True)

        # Creating a list of all the indexes of isd_el['type'] = 'Title'
        title_indexes = [i for i, isd_el in enumerate(isd) if isd_el['type'] == 'Title']

        # Creating 'Chunks' - a list of lists of strings,
        # each list starting with isd_el['type'] = 'Title' and holding all the data till the next 'Title'.
        # Each chunk can be thought of as an individual set of data that can be sent to the model,
        # where each Title is grouped together with the data under it.
        Chunks = [[]]
        final_chunks = list(list())

        for i, isd_el in enumerate(isd):
            if i in title_indexes:
                Chunks.append([])
            Chunks[-1].append(isd_el['text'])

        # Drop chunks whose combined string length is < 25.  #TODO: this threshold could be a user-defined variable
        for chunk in Chunks:
            # sum of the lengths of all the strings in the chunk
            chunk_len = sum(len(str(item)) for item in chunk)
            if chunk_len < 25:
                continue
            # append the approved chunk to final_chunks as a single string
            final_chunks.append(" ".join([str(item) for item in chunk]))
        return final_chunks
build_index/parser/pdf_parser.py
ADDED
@@ -0,0 +1,27 @@
# -*-coding:utf-8 -*-
import PyPDF2
from build_index.parser.base import BaseParser


class PDFParser(BaseParser):
    def header_remove(self):
        # Remove the page header of the research report
        pass

    def footnote_remove(self):
        # Remove the page footer of the research report
        pass

    def parse_file(self, file):
        # store the text of each page
        text_list = []

        with open(file, "rb") as fp:
            pdf = PyPDF2.PdfReader(fp)
            num_pages = len(pdf.pages)
            for page in range(num_pages):
                page_text = pdf.pages[page].extract_text()
                text_list.append(page_text)
        text = '\n'.join(text_list)
        metadata = {'source': file, 'pages': num_pages}
        return text, metadata
build_index/pricing.py
ADDED
@@ -0,0 +1,23 @@
# -*-coding:utf-8 -*-
import tiktoken


def num_tokens_from_string(string: str, encoding_name: str):
    # Function to convert string to tokens and estimate user cost.
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    total_price = ((num_tokens / 1000) * 0.0004)
    return num_tokens, total_price


def check_price(docs):
    docs_content = ""
    for doc in docs:
        docs_content += doc.page_content

    tokens, total_price = num_tokens_from_string(string=docs_content, encoding_name="cl100k_base")

    print(f"Number of Tokens = {format(tokens, ',d')}")
    print(f"Approx Cost = ${format(total_price, ',.2f')}")
    user_input = input("Price Okay? (Y/N) \n").upper() == 'Y'
    return user_input
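For reference, a minimal sketch of the token/price estimate above, using the same cl100k_base encoding and $0.0004-per-1K-token rate as the module:

# Sketch: estimate token count and cost for a raw string.
from build_index.pricing import num_tokens_from_string

tokens, price = num_tokens_from_string("半导体行业2023年策略展望", encoding_name="cl100k_base")
print(tokens, f"${price:.6f}")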
build_index/process.py
ADDED
@@ -0,0 +1,78 @@
# -*-coding:utf-8 -*-
"""
Split Document into chunks
"""

import re
import tiktoken

from typing import List
from build_index.base import Document
from math import ceil


def separate_header_and_body(text):
    header_pattern = r"^(.*?\n){3}"
    match = re.match(header_pattern, text)
    header = match.group(0)
    body = text[len(header):]
    return header, body


def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]:
    docs = []
    current_group = None

    for doc in documents:
        doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))

        if current_group is None:
            current_group = Document(text=doc.text, doc_id=doc.doc_id, embedding=doc.embedding,
                                     extra_info=doc.extra_info)
        elif len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len >= min_tokens:
            current_group.text += " " + doc.text
        else:
            docs.append(current_group)
            current_group = Document(text=doc.text, doc_id=doc.doc_id, embedding=doc.embedding,
                                     extra_info=doc.extra_info)

    if current_group is not None:
        docs.append(current_group)

    return docs


def split_documents(documents: List[Document], max_tokens: int) -> List[Document]:
    docs = []
    for doc in documents:
        token_length = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
        if token_length <= max_tokens:
            docs.append(doc)
        else:
            header, body = separate_header_and_body(doc.text)
            num_body_parts = ceil(token_length / max_tokens)
            part_length = ceil(len(body) / num_body_parts)
            body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)]
            for i, body_part in enumerate(body_parts):
                new_doc = Document(text=header + body_part.strip(),
                                   doc_id=f"{doc.doc_id}-{i}",
                                   embedding=doc.embedding,
                                   extra_info=doc.extra_info)
                docs.append(new_doc)
    return docs


def group_split(documents: List[Document], max_tokens: int = 2000, min_tokens: int = 150, token_check: bool = True):
    if not token_check:
        return documents
    print("Grouping small documents")
    try:
        documents = group_documents(documents=documents, min_tokens=min_tokens, max_tokens=max_tokens)
    except Exception:
        print("Grouping failed, try running without token_check")
    print("Separating large documents")
    try:
        documents = split_documents(documents=documents, max_tokens=max_tokens)
    except Exception:
        print("Splitting failed, try running without token_check")
    return documents
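For reference, a minimal sketch of group_split on a toy document; the text is synthetic and starts with three header lines so separate_header_and_body can match them:

# Sketch: chunk one oversized Document into several parts.
from build_index.base import Document
from build_index.process import group_split

long_text = "标题\n日期\n来源\n" + "正文内容。" * 800          # well over max_tokens once tokenized
docs = group_split([Document(text=long_text)], max_tokens=500, min_tokens=150)
print(len(docs), [d.doc_id for d in docs])                     # doc_id gets a part-index suffix per chunk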
build_index/run.py
ADDED
@@ -0,0 +1,56 @@
# -*-coding:utf-8 -*-
"""
Build Faiss Index for all documents in ./data
"""

import argparse
from build_index.base import FileReader, Document
from build_index.process import split_documents
from build_index.pricing import check_price
from build_index.doc2vec import doc2vec


def process_directory(directory, folder_name, model, max_tokens):
    reader = FileReader(directory, folder_name)
    docs = reader.load_data()
    print(f'Before chunking {len(docs)} docs')
    new_docs = split_documents(documents=docs, max_tokens=max_tokens)
    print(f'After chunking {len(new_docs)} docs')
    docs = [Document.to_langchain_format(doc) for doc in new_docs]

    doc2vec(docs, model, folder_name)
    return docs


def process_files(files, model, max_tokens):
    reader = FileReader(input_files=files)
    docs = reader.load_data()
    print(f'Before chunking {len(docs)} docs')
    new_docs = split_documents(documents=docs, max_tokens=max_tokens)
    print(f'After chunking {len(new_docs)} docs')
    docs = [Document.to_langchain_format(doc) for doc in new_docs]

    doc2vec(docs, model)
    return docs


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str, default='./data')
    parser.add_argument("folder", type=str, default='半导体')
    parser.add_argument("model", type=str, default='mpnet')
    parser.add_argument('max_tokens', type=int, default=514)
    args = parser.parse_args()

    process_directory(args.directory, args.folder, args.model, args.max_tokens)

    process_directory('./data', '半导体', 'mpnet', 514)
    process_directory('./data', '数字经济', 'cohere', 514)
    process_directory('./data', '数字经济', 'cohere', 514)
    process_directory('./data', '两会', 'cohere', 514)

    from glob import glob
    files = glob('./data/两会/*.pdf')
    files = files[:3]
    process_files(files[:3], 'cohere', 514)
build_index/unit_test/__init__.py
ADDED
@@ -0,0 +1 @@
# -*-coding:utf-8 -*-
build_index/unit_test/test_faiss.py
ADDED
@@ -0,0 +1,33 @@
# -*-coding:utf-8 -*-

from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings, CohereEmbeddings
from key import CoherenceKey

## MPNET
model = HuggingFaceEmbeddings()
db = FAISS.load_local('./output/半导体', model)

docs = db.similarity_search('东吴证券观点')
print(docs[0].page_content)

docs = db.similarity_search('德邦证券')
print(docs[0].page_content)


## Coherence
model = CohereEmbeddings(cohere_api_key=CoherenceKey)
db = FAISS.load_local('./output/半导体', model)

docs = db.similarity_search('半导体指数行情')
print(docs[0].page_content)

docs = db.similarity_search('关于行业光刻胶相关新闻')
print(docs[0].page_content)

docs = db.similarity_search('2023年GDP增速预测')
print(docs[0].page_content)
build_index/unit_test/test_loader.py
ADDED
@@ -0,0 +1 @@
# -*-coding:utf-8 -*-
key.py
ADDED
@@ -0,0 +1,4 @@
# -*-coding:utf-8 -*-
CoherenceKey = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
#OpenaiKey='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'
OpenaiKey='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'
prompts/__init__.py
ADDED
@@ -0,0 +1,22 @@
# -*-coding:utf-8 -*-

# load prompt templates
with open("prompts/combine_prompt.txt", "r") as f:
    template = f.read()

with open("prompts/combine_prompt_hist.txt", "r") as f:
    template_hist = f.read()

with open("prompts/chat_combine_prompt.txt", "r") as f:
    chat_combine_template = f.read()

with open("prompts/chat_reduce_prompt.txt", "r") as f:
    chat_reduce_template = f.read()


MyTemplate = {
    'chat_reduce_template': chat_reduce_template,
    'chat_combine_template': chat_combine_template,
    'template_hist': template_hist,
    'template': template
}
prompts/chat_combine_prompt.txt
ADDED
@@ -0,0 +1,4 @@
You are DocsGPT, a friendly and helpful AI assistant by TianHong Asset Management that provides help with documents and financial news. You give thorough answers with detailed numbers and illustrative examples if possible.
Use the following pieces of context to help answer the user's question; always answer in Chinese.
----------------
{summaries}
prompts/chat_reduce_prompt.txt
ADDED
@@ -0,0 +1,3 @@
Use the following portion of a long document to see if any of the text is relevant to answer the question.
{context}
Provide all relevant text to the question verbatim. Summarize if needed. Answer in Chinese and illustratively. If nothing relevant return "-".
prompts/combine_prompt.txt
ADDED
@@ -0,0 +1,25 @@
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.

QUESTION: How to merge tables in pandas?
=========
Content: pandas provides various facilities for easily combining together Series or DataFrame with various kinds of set logic for the indexes and relational algebra functionality in the case of join / merge-type operations.
Source: 28-pl
Content: pandas provides a single function, merge(), as the entry point for all standard database join operations between DataFrame or named Series objects: \n\npandas.merge(left, right, how='inner', on=None, left_on=None, right_on=None, left_index=False, right_index=False, sort=False, suffixes=('_x', '_y'), copy=True, indicator=False, validate=None)
Source: 30-pl
=========
FINAL ANSWER: To merge two tables in pandas, you can use the pd.merge() function. The basic syntax is: \n\npd.merge(left, right, on, how) \n\nwhere left and right are the two tables to merge, on is the column to merge on, and how is the type of merge to perform. \n\nFor example, to merge the two tables df1 and df2 on the column 'id', you can use: \n\npd.merge(df1, df2, on='id', how='inner')
SOURCES: 28-pl 30-pl

QUESTION: How are you?
=========
CONTENT:
SOURCE:
=========
FINAL ANSWER: I am fine, thank you. How are you?
SOURCES:

QUESTION: {{ question }}
=========
{{ summaries }}
=========
FINAL ANSWER:
prompts/combine_prompt_hist.txt
ADDED
@@ -0,0 +1,33 @@
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples if possible.

QUESTION: How to merge tables in pandas?
=========
Content: pandas provides various facilities for easily combining together Series or DataFrame with various kinds of set logic for the indexes and relational algebra functionality in the case of join / merge-type operations.
Source: 28-pl
Content: pandas provides a single function, merge(), as the entry point for all standard database join operations between DataFrame or named Series objects: \n\npandas.merge(left, right, how='inner', on=None, left_on=None, right_on=None, left_index=False, right_index=False, sort=False, suffixes=('_x', '_y'), copy=True, indicator=False, validate=None)
Source: 30-pl
=========
FINAL ANSWER: To merge two tables in pandas, you can use the pd.merge() function. The basic syntax is: \n\npd.merge(left, right, on, how) \n\nwhere left and right are the two tables to merge, on is the column to merge on, and how is the type of merge to perform. \n\nFor example, to merge the two tables df1 and df2 on the column 'id', you can use: \n\npd.merge(df1, df2, on='id', how='inner')
SOURCES: 28-pl 30-pl

QUESTION: How are you?
=========
CONTENT:
SOURCE:
=========
FINAL ANSWER: I am fine, thank you. How are you?
SOURCES:

QUESTION: {{ historyquestion }}
=========
CONTENT:
SOURCE:
=========
FINAL ANSWER: {{ historyanswer }}
SOURCES:

QUESTION: {{ question }}
=========
{{ summaries }}
=========
FINAL ANSWER:
prompts/question_prompt.txt
ADDED
@@ -0,0 +1,4 @@
Use the following portion of a long document to see if any of the text is relevant to answer the question.
{{ context }}
Question: {{ question }}
Provide all relevant text to the question verbatim. Summarize if needed. If nothing relevant return "-".
prompts/style.py
ADDED
@@ -0,0 +1,14 @@
# -*-coding:utf-8 -*-

Prompt = {
    # Polish writing: fix spelling and grammar, improve readability
    'Better': 'As a writing improvement assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations. Please begin by editing the following text: {Text}',
    # Xiaohongshu style: suited to promotion and marketing copy
    'RedBook': 'Please edit the following passage in Chinese using the Xiaohongshu style, which is characterized by captivating headlines, the inclusion of emoticons in each paragraph, and the addition of relevant tags at the end. Be sure to maintain the original meaning of the text. Please begin by editing the following text: {Text}',
    # De-colloquialize: turn spoken language (e.g. user interviews) into formal text
    'Decolloquialism': 'Using concise and clear language, please edit the following passage to improve its logical flow, eliminate any typographical errors and respond in Chinese. Be sure to maintain the original meaning of the text. Please begin by editing the following text: {Text}',
    # Abstractive summary
    'Summarize': 'Summarize the following text into 100 words, making it easy to read and comprehend. The summary should be concise, clear, and capture the main points of the text. Avoid using complex sentence structures or technical jargon. Please begin by editing the following text: {Text}',
    # Humorous style
    'Humor': 'Rewrite the article with a stylistic flair that retains the original semantic meaning while making it more humorous and witty. Respond in Chinese. Please begin by editing the following text: {Text} '
}
requirements.txt
ADDED
@@ -0,0 +1,10 @@
openai==0.27.2
gradio==3.21.0
langchain==0.0.113
unstructured==0.4.11
PyPDF2==3.0.1
tiktoken==0.1.2
dataclasses_json==0.5.7
retry==0.9.2
cohere==3.4.0
faiss-cpu==1.7.3