Update app.py
app.py
CHANGED
@@ -15,7 +15,7 @@ from langchain.vectorstores import Chroma
 
 
 
-@st.cache_resource
+@st.cache_resource(show_spinner=False)
 def load_model(model_name):
     logger.info("Loading model ..")
     start_time = time.time()
@@ -48,7 +48,7 @@ def load_model(model_name):
     return model, tokenizer
 
 
-@st.cache_resource
+@st.cache_resource(show_spinner=False)
 def load_db(device, local_embed=False, CHROMA_PATH = './ChromaDB'):
     """
     Load vector embeddings and Chroma database
@@ -64,12 +64,9 @@ def load_db(device, local_embed=False, CHROMA_PATH = './ChromaDB'):
         PATH_TO_EMBEDDING_FOLDER = ""
         # TODO : load only pytorch bin file
        embeddings = AutoModel.from_pretrained(PATH_TO_EMBEDDING_FOLDER, trust_remote_code=True)
-        embeddings = HuggingFaceBgeEmbeddings(model_name="
+        embeddings = HuggingFaceBgeEmbeddings(model_name=" ", model_kwargs={"trust_remote_code": True})
         logger.info('Loading embeddings locally.')
-
-        embed = embeddings.get_text_embedding("Hello World!")
-        print(len(embed))
-        print(embed[:5])
+
 
     else:
         embeddings = HuggingFaceBgeEmbeddings(model_name=embed_id, model_kwargs={"device": device}, encode_kwargs=encode_kwargs)
@@ -160,15 +157,15 @@ def llm_chain_with_context(model, model_name, query, context, template):
 
 def generate_response(query, model, template):
     start_time = time.time()
-    progress_text = "
+    progress_text = "Running Inference. Please wait."
     my_bar = st.progress(0, text=progress_text)
-    context = fetch_context(db, model, model_name, query, template)
     # fill those as appropriate
-    my_bar.progress(0.1, "Loading Database. Please wait.")
+    #my_bar.progress(0.1, "Loading Database. Please wait.")
 
-    my_bar.progress(0.3, "Loading Model. Please wait.")
+    #my_bar.progress(0.3, "Loading Model. Please wait.")
 
-    my_bar.progress(0.
+    my_bar.progress(0.1, "Running RAG. Please wait.")
+    context = fetch_context(db, model, model_name, query, template)
 
     my_bar.progress(0.7, "Generating Answer. Please wait.")
     response = llm_chain_with_context(model, model_name, query, context, template)
@@ -205,6 +202,12 @@ def set_as_background_img(png_file):
     st.markdown(background_img, unsafe_allow_html=True)
     return
 
+
+def stream_to_screen(response):
+    for word in response.split():
+        yield word + " "
+        time.sleep(0.05)
+
 
 if __name__=="__main__":
 
@@ -272,10 +275,17 @@ if __name__=="__main__":
 
     Question: {question}\n> Context:\n>>>\n{context}\n>>>\nRelevant parts"""}
 
-
+    # Loading and caching db and model
+    my_bar = st.progress(0, "Loading Database. Please wait.")
+    my_bar.progress(0.1, "Loading Embedding & Database. Please wait.")
     db = load_db(device)
+    my_bar.progress(0.7, "Loading Model. Please wait.")
     model, tokenizer = load_model(model_name)
+    my_bar.progress(1.0, "Done")
+    time.sleep(1)
+    my_bar.empty()
 
+
     response = False
     user_question = st.chat_input('What do you want to ask ..')
 
@@ -286,11 +296,12 @@
         st.write(user_question)
 
     if response:
-
-
+        with st.chat_message("AI", avatar="🏛️"):
+            # to empty response container after first pass
+            st.write(" ")
 
     response = generate_response(user_question, model, all_templates)
     with st.chat_message("AI", avatar="🏛️"):
-        st.write(response)
+        st.write(stream_to_screen(response))
 
 
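A note on the caching change: @st.cache_resource(show_spinner=False) keeps the heavy objects (model, tokenizer, Chroma handle) alive across Streamlit reruns, and show_spinner=False suppresses the built-in "Running load_model(...)" spinner so it does not compete with the manual progress bar added under __main__. A minimal sketch of the pattern, with illustrative names that are not from app.py:

import time
import streamlit as st

@st.cache_resource(show_spinner=False)
def load_expensive_resource(name):
    # Runs once per unique `name`; later reruns return the cached object.
    time.sleep(2)  # stand-in for model / database loading
    return {"name": name}

bar = st.progress(0, "Loading. Please wait.")
resource = load_expensive_resource("demo")  # near-instant after the first run
bar.progress(1.0, "Done")
time.sleep(1)
bar.empty()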
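The local-embeddings branch now goes through HuggingFaceBgeEmbeddings with trust_remote_code instead of the raw AutoModel, but the model_name in the diff is left blank, so it still needs a real path. A hedged sketch of how that branch could look; the folder path below is hypothetical, not the repo's actual path:

from langchain.embeddings import HuggingFaceBgeEmbeddings

# Hypothetical local folder holding the embedding weights.
PATH_TO_EMBEDDING_FOLDER = "./models/bge-base-en-v1.5"

# sentence-transformers accepts a local directory in place of a hub id,
# so the same wrapper serves both the local and the remote case.
embeddings = HuggingFaceBgeEmbeddings(
    model_name=PATH_TO_EMBEDDING_FOLDER,
    model_kwargs={"trust_remote_code": True},
)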
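The new stream_to_screen generator makes the answer appear word by word: recent Streamlit versions render a generator passed to st.write as a stream, and the dedicated st.write_stream API does the same thing explicitly. A self-contained sketch, assuming the response is a plain string:

import time
import streamlit as st

def stream_to_screen(response):
    # Yield one word at a time so Streamlit can paint the answer incrementally.
    for word in response.split():
        yield word + " "
        time.sleep(0.05)

with st.chat_message("AI", avatar="🏛️"):
    st.write_stream(stream_to_screen("This answer appears word by word."))

If the installed Streamlit predates generator support in st.write, st.write_stream as used above is the safer call.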