ShawnAI committed on
Commit
735f8b3
·
1 Parent(s): 353b744

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -33
app.py CHANGED
@@ -8,7 +8,6 @@ from langchain.chat_models import ChatOpenAI
8
  from langchain.embeddings import HuggingFaceEmbeddings
9
  from langchain.vectorstores import Pinecone
10
  from langchain.chains import LLMChain
11
- from langchain.chains.retrieval_qa.base import RetrievalQA
12
  from langchain.chains.question_answering import load_qa_chain
13
  import pinecone
14
 
@@ -33,6 +32,7 @@ EMBEDDING_MODEL = os.environ.get("PINECONE_INDEX", "sentence-transformers/all-mp
33
  # return top-k text chunks from vector store
34
  TOP_K_DEFAULT = 15
35
  TOP_K_MAX = 30
 
36
 
37
 
38
  BUTTON_MIN_WIDTH = 215
@@ -106,6 +106,45 @@ init_message = f"""This demonstration website is based on \
106
  2. Insert your **Question** and click `{KEY_SUBMIT}`
107
  """
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  #----------------------------------------------------------------------------------------------------------
111
  #----------------------------------------------------------------------------------------------------------
@@ -122,7 +161,8 @@ def init_model(api_key, emb_name, db_api_key, db_env, db_index):
122
  if llm_name == "gpt-3.5-turbo":
123
  llm_dict[llm_name] = ChatOpenAI(model_name=llm_name,
124
  temperature = OPENAI_TEMP,
125
- openai_api_key = api_key)
 
126
  else:
127
  llm_dict[llm_name] = OpenAI(model_name=llm_name,
128
  temperature = OPENAI_TEMP,
@@ -148,22 +188,25 @@ def init_model(api_key, emb_name, db_api_key, db_env, db_index):
148
  def get_chat_history(inputs) -> str:
149
  res = []
150
  for human, ai in inputs:
151
- res.append(f"Human: {human}\nAI: {ai}")
152
  return "\n".join(res)
153
 
154
- def remove_duplicates(documents):
155
  seen_content = set()
156
  unique_documents = []
157
- for doc in documents:
158
- if doc.page_content not in seen_content:
159
  seen_content.add(doc.page_content)
160
  unique_documents.append(doc)
161
  return unique_documents
162
 
163
- def doc_similarity(query, db, top_k):
164
- docsearch = db.as_retriever(search_kwargs={'k':top_k})
165
- docs = docsearch.get_relevant_documents(query)
166
- udocs = remove_duplicates(docs)
 
 
 
167
  return udocs
168
 
169
  def user(user_message, history):
@@ -171,7 +214,7 @@ def user(user_message, history):
171
 
172
  def bot(box_message, ref_message,
173
  llm_dropdown, llm_dict, doc_list,
174
- db, top_k):
175
 
176
  # bot_message = random.choice(["Yes", "No"])
177
  # 0 is user question, 1 is bot response
@@ -184,9 +227,9 @@ def bot(box_message, ref_message,
184
 
185
  if not ref_message:
186
  ref_message = question
187
- details = f"Q: {question}"
188
  else:
189
- details = f"Q: {question}\nR: {ref_message}"
190
 
191
 
192
  llm = llm_dict[llm_dropdown]
@@ -196,27 +239,31 @@ def bot(box_message, ref_message,
196
  box_message[-1][1] = DOCS_WARNING
197
  return box_message, "", ""
198
 
199
- chain = load_qa_chain(llm, chain_type="stuff")
200
- docs = doc_similarity(ref_message, db, top_k)
201
  delta_top_k = top_k - len(docs)
202
 
203
  if delta_top_k > 0:
204
- docs = doc_similarity(ref_message, db, top_k+delta_top_k)
205
-
 
 
 
206
  else:
207
- chain = LLMChain(llm = llm,
208
- prompt = PromptTemplate(template='{question}',
209
- input_variables=['question']),
210
- output_key = 'output_text')
211
  docs = []
212
-
213
- all_output = chain({"input_documents": docs,
214
- "question": question,
215
- "chat_history": get_chat_history(history)})
 
 
 
 
 
 
216
 
217
  bot_message = all_output['output_text']
218
 
219
-
220
  source = "".join([f"""<details> <summary>{doc.metadata["source"]}</summary>
221
  {doc.page_content}
222
 
@@ -288,12 +335,21 @@ with gr.Blocks(
288
 
289
 
290
  with gr.Tab(TAB_2):
291
- top_k = gr.Slider(1,
292
- TOP_K_MAX,
293
- value=TOP_K_DEFAULT,
294
- step=1,
295
- label="Vector similarity top_k",
296
- interactive=True)
 
 
 
 
 
 
 
 
 
297
  detail_panel = gr.Chatbot(label="Related Docs")
298
 
299
  with gr.Tab(TAB_3):
@@ -349,7 +405,7 @@ with gr.Blocks(
349
  bot,
350
  [chatbot, ref,
351
  llm_dropdown, llm, doc_check,
352
- vector_db, top_k],
353
  [chatbot, ref, detail_panel]
354
  )
355
 
 
8
  from langchain.embeddings import HuggingFaceEmbeddings
9
  from langchain.vectorstores import Pinecone
10
  from langchain.chains import LLMChain
 
11
  from langchain.chains.question_answering import load_qa_chain
12
  import pinecone
13
 
 
32
  # return top-k text chunks from vector store
33
  TOP_K_DEFAULT = 15
34
  TOP_K_MAX = 30
35
+ SCORE_DEFAULT = 0.3
36
 
37
 
38
  BUTTON_MIN_WIDTH = 215
 
106
  2. Insert your **Question** and click `{KEY_SUBMIT}`
107
  """
108
 
109
+ PROMPT_DOC = PromptTemplate(
110
+ input_variables=["context", "chat_history", "question"],
111
+ template="""Context:
112
+ ##
113
+ {context}
114
+ ##
115
+
116
+ Chat History:
117
+ ##
118
+ {chat_history}
119
+ ##
120
+
121
+ Question:
122
+ {question}
123
+
124
+ Optinal:
125
+ Don't use standalone clause/figure name in the answer, expand it with corresponding metadata TS name
126
+
127
+ Desired format:
128
+ Clause/figure name: <dot_separated_numbers>
129
+ TS name: [\w\.]
130
+
131
+ Answer:"""
132
+ )
133
+
134
+ PROMPT_BASE = PromptTemplate(
135
+ input_variables=['question', "chat_history"],
136
+ template="""Chat History:
137
+ ##
138
+ {chat_history}
139
+ ##
140
+
141
+ Question:
142
+ ##
143
+ {question}
144
+ ##
145
+
146
+ Answer:"""
147
+ )
148
 
149
  #----------------------------------------------------------------------------------------------------------
150
  #----------------------------------------------------------------------------------------------------------
 
161
  if llm_name == "gpt-3.5-turbo":
162
  llm_dict[llm_name] = ChatOpenAI(model_name=llm_name,
163
  temperature = OPENAI_TEMP,
164
+ openai_api_key = api_key
165
+ )
166
  else:
167
  llm_dict[llm_name] = OpenAI(model_name=llm_name,
168
  temperature = OPENAI_TEMP,
 
188
  def get_chat_history(inputs) -> str:
189
  res = []
190
  for human, ai in inputs:
191
+ res.append(f"Q: {human}\nA: {ai}")
192
  return "\n".join(res)
193
 
194
+ def remove_duplicates(documents, score_min):
195
  seen_content = set()
196
  unique_documents = []
197
+ for (doc, score) in documents:
198
+ if (doc.page_content not in seen_content) and (score >= score_min):
199
  seen_content.add(doc.page_content)
200
  unique_documents.append(doc)
201
  return unique_documents
202
 
203
+ def doc_similarity(query, db, top_k, score):
204
+ docs = db.similarity_search_with_score(query = query,
205
+ k=top_k)
206
+ #docsearch = db.as_retriever(search_kwargs={'k':top_k})
207
+ #docs = docsearch.get_relevant_documents(query)
208
+ # print(docs)
209
+ udocs = remove_duplicates(docs, score)
210
  return udocs
211
 
212
  def user(user_message, history):
 
214
 
215
  def bot(box_message, ref_message,
216
  llm_dropdown, llm_dict, doc_list,
217
+ db, top_k, score):
218
 
219
  # bot_message = random.choice(["Yes", "No"])
220
  # 0 is user question, 1 is bot response
 
227
 
228
  if not ref_message:
229
  ref_message = question
230
+ details = f"Q: {question}"
231
  else:
232
+ details = f"Q: {question}\nR: {ref_message}"
233
 
234
 
235
  llm = llm_dict[llm_dropdown]
 
239
  box_message[-1][1] = DOCS_WARNING
240
  return box_message, "", ""
241
 
242
+ docs = doc_similarity(ref_message, db, top_k, score)
 
243
  delta_top_k = top_k - len(docs)
244
 
245
  if delta_top_k > 0:
246
+ docs = doc_similarity(ref_message, db, top_k+delta_top_k, score)
247
+
248
+ prompt = PROMPT_DOC
249
+ #chain = load_qa_chain(llm, chain_type="stuff")
250
+
251
  else:
252
+ prompt = PROMPT_BASE
 
 
 
253
  docs = []
254
+
255
+ chain = LLMChain(llm = llm,
256
+ prompt = prompt,
257
+ output_key = 'output_text')
258
+
259
+ all_output = chain({"question": question,
260
+ "context": docs,
261
+ "chat_history": get_chat_history(history)
262
+ })
263
+
264
 
265
  bot_message = all_output['output_text']
266
 
 
267
  source = "".join([f"""<details> <summary>{doc.metadata["source"]}</summary>
268
  {doc.page_content}
269
 
 
335
 
336
 
337
  with gr.Tab(TAB_2):
338
+ with gr.Row():
339
+ with gr.Column():
340
+ top_k = gr.Slider(1,
341
+ TOP_K_MAX,
342
+ value=TOP_K_DEFAULT,
343
+ step=1,
344
+ label="Vector similarity top_k",
345
+ interactive=True)
346
+ with gr.Column():
347
+ score = gr.Slider(0.01,
348
+ 0.99,
349
+ value=SCORE_DEFAULT,
350
+ step=0.01,
351
+ label="Vector similarity score",
352
+ interactive=True)
353
  detail_panel = gr.Chatbot(label="Related Docs")
354
 
355
  with gr.Tab(TAB_3):
 
405
  bot,
406
  [chatbot, ref,
407
  llm_dropdown, llm, doc_check,
408
+ vector_db, top_k, score],
409
  [chatbot, ref, detail_panel]
410
  )
411