hymai committed
Commit 2781c38 · verified · 1 Parent(s): 35996ad

Create app.py

Files changed (1):
  app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
+ #from llama_index.llms.ollama import Ollama
+ from llama_index.llms.groq import Groq
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, DocumentSummaryIndex
+ from llama_index.core import StorageContext, get_response_synthesizer
+ from llama_index.core.retrievers import VectorIndexRetriever
+ from llama_index.core.query_engine import RetrieverQueryEngine
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.core import load_index_from_storage
+ import os
+ from dotenv import load_dotenv
+ from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
+ from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.core.postprocessor import SimilarityPostprocessor
+ import time
+ import gradio as gr
+ from llama_index.core.memory import ChatMemoryBuffer
+ from llama_parse import LlamaParse
+ from llama_index.core import PromptTemplate
+ from llama_index.core.llms import ChatMessage, MessageRole
+ from llama_index.core.chat_engine import CondenseQuestionChatEngine
+
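+ # NOTE (assumption): the imports above roughly correspond to these packages,
+ # not confirmed against the author's environment:
+ #   pip install llama-index llama-index-llms-groq \
+ #       llama-index-embeddings-huggingface llama-parse python-dotenv gradio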
+
+ # load env file
+ load_dotenv()
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY')
+ LLAMAINDEX_API_KEY = os.getenv('LLAMAINDEX_API_KEY')
+
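+ # the .env file is expected to provide both keys (values below are placeholders):
+ #   GROQ_API_KEY=...
+ #   LLAMAINDEX_API_KEY=...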
+ # set up callback manager
+ llama_debug = LlamaDebugHandler(print_trace_on_end=True)
+ callback_manager = CallbackManager([llama_debug])
+ Settings.callback_manager = callback_manager
+
+ # set up LLM
+ llm = Groq(model="llama3-70b-8192", api_key=GROQ_API_KEY)  # alternative: "llama3-8b-8192"
+ Settings.llm = llm
+
+ # set up embedding model
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
+ Settings.embed_model = embed_model
+
+ # create splitter
+ splitter = SentenceSplitter(chunk_size=2048, chunk_overlap=50)
+ Settings.transformations = [splitter]
+
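+ # NOTE (assumption): bge-small-en-v1.5 truncates input at roughly 512 tokens,
+ # so 2048-token chunks are likely embedded from their leading tokens only; a
+ # smaller chunk_size (e.g. 512) may retrieve more precisely.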
+ # create parser
+ parser = LlamaParse(
+     api_key=LLAMAINDEX_API_KEY,
+     result_type="markdown",  # "markdown" and "text" are available
+     verbose=True,
+ )
+
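+ # NOTE: LlamaParse is LlamaIndex's hosted parsing service, so this needs a
+ # valid LlamaCloud API key and network access; "markdown" output tends to
+ # preserve tables and headings better than plain "text".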
+ # create index
+ if os.path.exists("./vectordb"):
+     print("Index exists!")
+     storage_context = StorageContext.from_defaults(persist_dir="./vectordb")
+     index = load_index_from_storage(storage_context)
+ else:
+     filename_fn = lambda filename: {"file_name": filename}
+     required_exts = [".pdf", ".docx"]
+     file_extractor = {".pdf": parser}
+     reader = SimpleDirectoryReader(
+         input_dir="./data",
+         file_extractor=file_extractor,
+         required_exts=required_exts,
+         recursive=True,
+         file_metadata=filename_fn,
+     )
+     documents = reader.load_data()
+     print(f"creating index with {len(documents)} documents")
+     index = VectorStoreIndex.from_documents(documents, embed_model=embed_model, transformations=[splitter])
+     index.storage_context.persist(persist_dir="./vectordb")
+
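+ # NOTE: the index is only built when no persisted copy is found, so delete the
+ # ./vectordb directory to force a rebuild after the files in ./data change.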
+ """
+ # create document summary index
+ if os.path.exists("./docsummarydb"):
+     print("Index exists!")
+     storage_context = StorageContext.from_defaults(persist_dir="./docsummarydb")
+     doc_index = load_index_from_storage(storage_context)
+ else:
+     filename_fn = lambda filename: {"file_name": filename}
+     required_exts = [".pdf", ".docx"]
+     reader = SimpleDirectoryReader(
+         input_dir="./data",
+         required_exts=required_exts,
+         recursive=True,
+         file_metadata=filename_fn,
+     )
+     documents = reader.load_data()
+     print(f"creating index with {len(documents)} documents")
+
+     response_synthesizer = get_response_synthesizer(
+         response_mode="tree_summarize", use_async=True
+     )
+     doc_index = DocumentSummaryIndex.from_documents(
+         documents,
+         llm=llm,
+         transformations=[splitter],
+         response_synthesizer=response_synthesizer,
+         show_progress=True,
+     )
+     doc_index.storage_context.persist(persist_dir="./docsummarydb")
+ """
+ """
+ retriever = DocumentSummaryIndexEmbeddingRetriever(
+     doc_index,
+     similarity_top_k=5,
+ )
+ """
+
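+ # NOTE (assumption): re-enabling the commented-out retriever above would also
+ # need `from llama_index.core.indices.document_summary import
+ # DocumentSummaryIndexEmbeddingRetriever`.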
+ # set up retriever
+ retriever = VectorIndexRetriever(
+     index=index,
+     similarity_top_k=10,
+     #vector_store_query_mode="mmr",
+     #vector_store_kwargs={"mmr_threshold": 0.4}
+ )
+
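+ # NOTE: the top 10 chunks are retrieved here and then filtered by the
+ # similarity cutoff in the query engine below; enabling the commented MMR
+ # options would trade some relevance for diversity among retrieved chunks.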
+ # set up response synthesizer
+ response_synthesizer = get_response_synthesizer(response_mode="tree_summarize", verbose=True)
+
+ ### customising prompts worsened the results ###
+ """
+ # set up prompt template
+ qa_prompt_tmpl = (
+     "Context information from multiple sources is below.\n"
+     "---------------------\n"
+     "{context_str}\n"
+     "---------------------\n"
+     "Given the information from multiple sources and not prior knowledge, "
+     "answer the query.\n"
+     "Query: {query_str}\n"
+     "Answer: "
+ )
+ qa_prompt = PromptTemplate(qa_prompt_tmpl)
+ """
+
+ # set up query engine
+ query_engine = RetrieverQueryEngine(
+     retriever=retriever,
+     node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.53)],
+     response_synthesizer=response_synthesizer,
+ )
+ print(query_engine.get_prompts())
+
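+ # NOTE: retrieved nodes scoring below the 0.53 similarity cutoff are dropped
+ # before synthesis, so borderline questions may reach the LLM with little or
+ # no supporting context.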
+ #response = query_engine.query("What happens if the distributor wants its own warehouse for pizzahood?")
+ #print(response)
+
+
+ memory = ChatMemoryBuffer.from_defaults(token_limit=10000)
+
+ custom_prompt = PromptTemplate(
+     """\
+ Given a conversation (between Human and Assistant) and a follow up message from Human, \
+ rewrite the message to be a standalone question that captures all relevant context \
+ from the conversation. If you are unsure, ask for more information.
+ <Chat History>
+ {chat_history}
+ <Follow Up Message>
+ {question}
+ <Standalone question>
+ """
+ )
+
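+ # NOTE: the chat engine below uses this prompt to condense each follow-up into
+ # a standalone question, which is then answered by the query engine.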
+ # list of `ChatMessage` objects seeding the conversation
+ custom_chat_history = [
+     ChatMessage(
+         role=MessageRole.USER,
+         content="Hello assistant.",
+     ),
+     ChatMessage(role=MessageRole.ASSISTANT, content="Hello user."),
+ ]
+
+ chat_engine = CondenseQuestionChatEngine.from_defaults(
+     query_engine=query_engine,
+     condense_question_prompt=custom_prompt,
+     chat_history=custom_chat_history,
+     verbose=True,
+     memory=memory,
+ )
+
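+ # quick sanity check (hypothetical usage, not part of the app flow):
+ #   print(chat_engine.chat("What is this document about?"))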
+
+
+ # gradio with streaming support
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox(label="⏎ for sending",
+                      placeholder="Ask me something")
+     clear = gr.Button("Delete")
+
+     def user(user_message, history):
+         # append the user turn and clear the textbox
+         return "", history + [[user_message, None]]
+
+     def bot(history):
+         user_message = history[-1][0]
+         #bot_message = chat_engine.chat(user_message)
+         bot_message = query_engine.query(
+             user_message
+             + " Let's think step by step to get the correct answer."
+             " If you cannot provide an answer, say you don't know."
+         )
+         # simulate streaming by emitting the finished response character by character
+         history[-1][1] = ""
+         for character in bot_message.response:
+             history[-1][1] += character
+             time.sleep(0.01)
+             yield history
+
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, chatbot, chatbot
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+ # demo.queue()
+ demo.launch(share=False)
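+ # NOTE (assumption): on older Gradio 3.x releases, generator callbacks like
+ # `bot` only stream when the queue is enabled, so `demo.queue()` may need to
+ # be re-enabled there; newer releases enable queuing by default.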