Manel committed on
Commit 455c294 · verified · 1 Parent(s): b096383

Update app.py

Files changed (1): app.py +27 -16
app.py CHANGED
@@ -15,7 +15,7 @@ from langchain.vectorstores import Chroma
 
 
 
-@st.cache_resource
+@st.cache_resource(show_spinner=False)
 def load_model(model_name):
     logger.info("Loading model ..")
     start_time = time.time()
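
Context for this change: st.cache_resource keeps one shared instance of the returned object across Streamlit reruns and sessions, so the model and tokenizer are loaded only once, and show_spinner=False hides Streamlit's default "Running load_model(...)" spinner so the custom progress bar added later in this commit is the only loading indicator. A minimal sketch of the caching behaviour (names and timings are illustrative):

import time
import streamlit as st

@st.cache_resource(show_spinner=False)   # cache the object, hide the default spinner
def expensive_load():
    time.sleep(5)                        # stands in for loading model weights
    return object()                      # hypothetical heavy resource

resource = expensive_load()              # first run pays the cost; reruns return the cached object
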
@@ -48,7 +48,7 @@ def load_model(model_name):
     return model, tokenizer
 
 
-@st.cache_resource
+@st.cache_resource(show_spinner=False)
 def load_db(device, local_embed=False, CHROMA_PATH = './ChromaDB'):
     """
     Load vector embeddings and Chroma database
@@ -64,12 +64,9 @@ def load_db(device, local_embed=False, CHROMA_PATH = './ChromaDB'):
         PATH_TO_EMBEDDING_FOLDER = ""
         # TODO : load only pytorch bin file
         embeddings = AutoModel.from_pretrained(PATH_TO_EMBEDDING_FOLDER, trust_remote_code=True)
-        embeddings = HuggingFaceBgeEmbeddings(model_name="whatever-model-you-are-using", model_kwargs={"trust_remote_code":True})
+        embeddings = HuggingFaceBgeEmbeddings(model_name=" ", model_kwargs={"trust_remote_code":True})
         logger.info('Loading embeddings locally.')
-        # Test the local embeddings
-        embed = embeddings.get_text_embedding("Hello World!")
-        print(len(embed))
-        print(embed[:5])
+
 
     else:
         embeddings = HuggingFaceBgeEmbeddings(model_name=embed_id , model_kwargs={"device": device}, encode_kwargs=encode_kwargs)
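
A note on the smoke test removed in this hunk: get_text_embedding is the LlamaIndex embeddings API, while LangChain's HuggingFaceBgeEmbeddings exposes embed_query and embed_documents, so the removed calls would most likely have failed with an AttributeError. A hedged sketch of the equivalent check in LangChain terms (the model id is an assumed example, not necessarily the one this app uses):

from langchain.embeddings import HuggingFaceBgeEmbeddings

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5",   # assumed model id
                                      model_kwargs={"device": "cpu"})
vec = embeddings.embed_query("Hello World!")   # returns a list of floats
print(len(vec), vec[:5])
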
@@ -160,15 +157,15 @@ def llm_chain_with_context(model, model_name, query, context, template):
 
 def generate_response(query, model, template):
     start_time = time.time()
-    progress_text = "Loading model. Please wait."
+    progress_text = "Running Inference. Please wait."
     my_bar = st.progress(0, text=progress_text)
-    context = fetch_context(db, model, model_name, query, template)
     # fill those as appropriate
-    my_bar.progress(0.1, "Loading Database. Please wait.")
+    #my_bar.progress(0.1, "Loading Database. Please wait.")
 
-    my_bar.progress(0.3, "Loading Model. Please wait.")
+    #my_bar.progress(0.3, "Loading Model. Please wait.")
 
-    my_bar.progress(0.5, "Running RAG. Please wait.")
+    my_bar.progress(0.1, "Running RAG. Please wait.")
+    context = fetch_context(db, model, model_name, query, template)
 
     my_bar.progress(0.7, "Generating Answer. Please wait.")
     response = llm_chain_with_context(model, model_name, query, context, template)
@@ -205,6 +202,12 @@ def set_as_background_img(png_file):
     st.markdown(background_img, unsafe_allow_html=True)
     return
 
+
+def stream_to_screen(response):
+    for word in response.split():
+        yield word + " "
+        time.sleep(0.05)
+
 
 if __name__=="__main__":
 
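
The stream_to_screen generator added above yields the answer one word at a time. Recent Streamlit releases can render a generator passed to st.write incrementally, and st.write_stream (added around Streamlit 1.31; the exact version is an assumption worth checking) is the explicit API for this, which also returns the concatenated text:

import time
import streamlit as st

def stream_to_screen(response):
    for word in response.split():
        yield word + " "
        time.sleep(0.05)   # small delay for a typing effect

full_text = st.write_stream(stream_to_screen("a streamed answer"))   # returns the full string
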
@@ -272,10 +275,17 @@ if __name__=="__main__":
 
     Question: {question}\n> Context:\n>>>\n{context}\n>>>\nRelevant parts"""}
 
-
+    # Loading and caching db and model
+    my_bar = st.progress(0, "Loading Database. Please wait.")
+    my_bar.progress(0.1, "Loading Embedding & Database. Please wait.")
     db = load_db(device)
+    my_bar.progress(0.7, "Loading Model. Please wait.")
     model, tokenizer = load_model(model_name)
+    my_bar.progress(1.0, "Done")
+    time.sleep(1)
+    my_bar.empty()
 
+
     response = False
     user_question = st.chat_input('What do you want to ask ..')
 
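
The startup sequence added above follows the usual st.progress lifecycle: create the bar at 0, advance it as each stage completes, fill it to 1.0, then remove it. A minimal sketch of that lifecycle (stage labels are placeholders):

import time
import streamlit as st

bar = st.progress(0, text="Starting. Please wait.")
bar.progress(0.5, text="Halfway there.")   # accepts an int in 0-100 or a float in 0.0-1.0
time.sleep(1)                              # let the final state stay visible briefly
bar.progress(1.0, text="Done")
bar.empty()                                # remove the bar from the page
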
@@ -286,11 +296,12 @@ if __name__=="__main__":
 
         st.write(user_question)
 
     if response:
-        # to empty response container after first pass
-        st.chat_message("AI", avatar="🏛️").write(" ")
+        with st.chat_message("AI", avatar="🏛️"):
+            # to empty response container after first pass
+            st.write(" ")
 
     response = generate_response(user_question, model, all_templates)
     with st.chat_message("AI", avatar="🏛️"):
-        st.write(response)
+        st.write(stream_to_screen(response))
 
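One design note on the last hunk: writing a blank AI chat message before generate_response runs keeps a message container on screen while the previous answer is being replaced. A common alternative (a hedged sketch, not what this commit does) is to reserve a slot with st.empty and overwrite it on each pass:

import streamlit as st

slot = st.empty()                        # placeholder whose contents later writes replace
with slot.container():
    with st.chat_message("AI", avatar="🏛️"):
        st.write("answer text here")     # re-running this block replaces the slot's contents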