Manel committed
Commit 1893b91 · verified · 1 Parent(s): fc08454

Update app.py

Files changed (1)
  1. app.py +293 -0
app.py CHANGED
@@ -0,0 +1,293 @@
+ import os
+ import time
+ import base64
+ import logging
+ import torch
+ import streamlit as st
+ from langchain.chains import LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain.llms import HuggingFacePipeline
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import LLMChainExtractor
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
+ from langchain.vectorstores import Chroma
+
+
+
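+ # Model loading is cached across Streamlit reruns. The 'llama' branch targets CPU-only Spaces with a
+ # quantized GGML model served through ctransformers; the 'mistral' branch loads a 4-bit quantized HF model on GPU.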
+ @st.cache_resource
+ def load_model(model_name, logger):
+     logger.info("Loading model ..")
+     start_time = time.time()
+
+     if model_name == 'llama':
+         from langchain.llms import CTransformers
+
+         model = CTransformers(model="TheBloke/Llama-2-7B-Chat-GGML",
+                               model_file='llama-2-7b-chat.ggmlv3.q2_K.bin',
+                               model_type='llama', gpu_layers=0,
+                               config={"context_length": 2048})
+         tokenizer = None
+
+     elif model_name == 'mistral':
+         from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+         model_id = "filipealmeida/Mistral-7B-Instruct-v0.1-sharded"
+
+         # 4-bit NF4 quantization so the 7B model fits in GPU memory
+         quant_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=torch.bfloat16)
+
+         model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True,
+                                                      quantization_config=quant_config, device_map="auto")
+
+         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+         tokenizer.pad_token = tokenizer.eos_token
+
+     logger.info(f"Model Loading Time : {time.time() - start_time} .")
+
+     return model, tokenizer
+
+
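+ # The embedding model loaded here must match the one used to build the persisted Chroma index,
+ # otherwise similarity search over ./ChromaDB returns meaningless neighbours.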
+ @st.cache_resource
+ def load_db(logger, device, local_embed=False, CHROMA_PATH='./ChromaDB'):
+     """
+     Load vector embeddings and Chroma database
+     """
+     encode_kwargs = {'normalize_embeddings': True}
+     embed_id = "BAAI/bge-large-en-v1.5"
+     start_time = time.time()
+
+     # TODO : look into loading only a single file from the HF repo to reduce memory
+     if local_embed:
+         # TODO : load only pytorch bin file
+         PATH_TO_EMBEDDING_FOLDER = ""
+         embeddings = HuggingFaceBgeEmbeddings(model_name=PATH_TO_EMBEDDING_FOLDER,
+                                               model_kwargs={"trust_remote_code": True},
+                                               encode_kwargs=encode_kwargs)
+         logger.info('Loading embeddings locally.')
+         # Test the local embeddings
+         embed = embeddings.embed_query("Hello World!")
+         print(len(embed))
+         print(embed[:5])
+
+     else:
+         embeddings = HuggingFaceBgeEmbeddings(model_name=embed_id, model_kwargs={"device": device},
+                                               encode_kwargs=encode_kwargs)
+         logger.info('Loading embeddings from Hub.')
+
+     db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
+     logger.info(f"Vector Embeddings and Chroma Database Loading Time : {time.time() - start_time} .")
+     return db
+
+
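+ # Only needed on the 'mistral' path: the raw transformers model is wrapped in a text-generation
+ # pipeline so it can be driven as a LangChain LLM by LLMChainExtractor and LLMChain below.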
+ def wrap_model(model, tokenizer):
+     """wrap transformers pipeline with HuggingFacePipeline
+     """
+     from transformers import pipeline
+
+     text_generation_pipeline = pipeline(
+         model=model,
+         tokenizer=tokenizer,
+         task="text-generation",
+         temperature=0.2,
+         repetition_penalty=1.1,
+         #return_full_text=True,
+         max_new_tokens=1000,
+         pad_token_id=2,
+         do_sample=True)
+     HF_pipeline = HuggingFacePipeline(pipeline=text_generation_pipeline)
+     return HF_pipeline
+
+
+
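+ # Example (hypothetical query): fetch_context(db, model, model_name, "How do I stay calm under pressure?",
+ # all_templates, logger) returns compressed chunks that format_context() joins into the prompt context.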
+ def fetch_context(db, model, model_name, query, template, logger, use_compressor=True):
+     """
+     Perform similarity search and retrieve context related to the query.
+     I have stored large documents in the db, so I apply a compressor on the set of retrieved documents
+     to make sure the returned compressed context is relevant to the query.
+     """
+     if use_compressor:
+         if model_name == 'llama':
+             compressor = LLMChainExtractor.from_llm(model)
+             compressor.llm_chain.prompt.template = template['llama_rag_prompt']
+
+         elif model_name == 'mistral':
+             global HF_pipeline_model
+             HF_pipeline_model = wrap_model(model, tokenizer)
+             compressor = LLMChainExtractor.from_llm(HF_pipeline_model)
+             compressor.llm_chain.prompt.template = template['rag_prompt']
+
+         retriever = db.as_retriever(search_type="mmr")
+         compression_retriever = ContextualCompressionRetriever(base_compressor=compressor,
+                                                                base_retriever=retriever)
+         logger.info(f"User Query : {query}")
+         compressed_docs = compression_retriever.get_relevant_documents(query)
+         logger.info(f"Retrieved Compressed Docs : {compressed_docs}")
+
+         return compressed_docs
+
+     docs = db.max_marginal_relevance_search(query)
+     logger.info(f"Retrieved Docs : {docs}")
+
+     return docs
+
+
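+ # Drops chunks that still contain the ">>>" delimiter used in the RAG prompts, then joins the
+ # remaining chunks into a single context string.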
+ def format_context(docs):
+     """
+     clean and format chunks into documents to pass as context
+     """
+     cleaned_docs = [doc for doc in docs if ">>>" not in doc.page_content]
+     return "\n\n".join(doc.page_content for doc in cleaned_docs)
+
+
+ def llm_chain_with_context(model, model_name, query, context, template, logger):
+     """
+     Run a simple chain with a formatted prompt, including the query and the retrieved context,
+     on the underlying model to generate a response.
+     """
+     formated_context = format_context(context)
+     # Give a precise answer to the question based on the context. Don't be verbose.
+     if model_name == 'llama':
+         prompt_template = PromptTemplate(input_variables=['context', 'user_query'], template=template['llama_prompt_template'])
+         llm_chain = LLMChain(llm=model, prompt=prompt_template)
+
+     elif model_name == 'mistral':
+         prompt_template = PromptTemplate(input_variables=['context', 'user_query'], template=template['prompt_template'])
+         llm_chain = LLMChain(llm=HF_pipeline_model, prompt=prompt_template)
+
+     output = llm_chain.predict(user_query=query, context=formated_context)
+     return output
+
+
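+ # Orchestrates a single chat turn: retrieve (and compress) context, run the LLM chain on it, and
+ # report progress in the Streamlit UI while the steps run.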
+ def generate_response(query, model, template, logger):
+     start_time = time.time()
+     progress_text = "Loading model. Please wait."
+     my_bar = st.progress(0, text=progress_text)
+     context = fetch_context(db, model, model_name, query, template, logger)
+     # fill those as appropriate
+     my_bar.progress(0.1, "Loading Database. Please wait.")
+
+     my_bar.progress(0.3, "Loading Model. Please wait.")
+
+     my_bar.progress(0.5, "Running RAG. Please wait.")
+
+     my_bar.progress(0.7, "Generating Answer. Please wait.")
+     response = llm_chain_with_context(model, model_name, query, context, template, logger)
+
+     logger.info(f"Total Execution Time: {time.time() - start_time}")
+
+     my_bar.progress(0.9, "Post Processing. Please wait.")
+
+     my_bar.progress(1.0, "Done")
+     time.sleep(1)
+     my_bar.empty()
+     return response
+
+
+ # show background image
+ def convert_to_base64(bin_file):
+     with open(bin_file, 'rb') as f:
+         data = f.read()
+     return base64.b64encode(data).decode()
+
+
+ def set_as_background_img(png_file):
+     bin_str = convert_to_base64(png_file)
+     background_img = '''
+     <link href='https://fonts.googleapis.com/css?family=Libre Baskerville' rel='stylesheet'>
+     <style>
+     .stApp {
+         background-image: url("data:image/png;base64,%s");
+         background-size: cover;
+         background-repeat: no-repeat;
+         background-attachment: scroll;
+     }
+     </style>
+     ''' % bin_str
+     st.markdown(background_img, unsafe_allow_html=True)
+     return
+
+
+ if __name__=="__main__":
+
+     st.set_page_config(page_title='StoicCyber', page_icon="🏛️", layout="centered", initial_sidebar_state="collapsed")
+     set_as_background_img('pxfuel.jpg')
+     # header
+     original_title = '<h1 style="font-family: Libre Baskerville; color:#faf8f8; font-size: 30px; text-align: left; ">STOIC Ω CYBER</h1>'
+     st.markdown(original_title, unsafe_allow_html=True)
+
+     user_question = st.chat_input('What do you want to ask ..')
+
+     # hide footer and header
+     hide_st_style = """
+     <style>
+     header {visibility: hidden;}
+     footer {visibility: hidden;}
+     </style>
+     """
+     st.markdown(hide_st_style, unsafe_allow_html=True)
+
+     # set logger
+     logger = logging.getLogger(__name__)
+     logging.basicConfig(
+         filename="app.log",
+         filemode="a",
+         format="%(asctime)s.%(msecs)03d %(levelname)s [%(funcName)s] %(message)s",
+         level=logging.INFO,
+         datefmt="%Y-%m-%d %H:%M:%S",)
+
+
+     # model to use in Spaces depends on the available device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     model_name = "llama" if device=="cpu" else "mistral"
+
+     logger.info(f'Running {model_name} model for inference on {device}')
+
+
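+     # Prompt templates: the Llama-2 chat format ([INST] / <<SYS>>) is used on the CPU path and a plain
+     # instruction format on the Mistral path; the llama_rag_prompt / rag_prompt entries drive contextual compression.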
+     all_templates = { "llama_prompt_template" : """<s>[INST]\n<<SYS>>\nYou are a stoic teacher that provides guidance and advice inspired by Stoic philosophy on navigating life's challenges with resilience and inner peace. Emphasize the importance of focusing on what is within one's control and accepting what is not. Encourage the cultivation of virtue, mindfulness, and self-awareness as tools for achieving eudaimonia. Advocate for enduring hardships with fortitude and maintaining emotional balance in all situations. Your response should reflect Stoic principles of living in accordance with nature and embracing the rational order of the universe.
+ You should guide the reader towards a fulfilling life focused on virtue rather than external things because living in accordance with virtue leads to eudaimonia or flourishing.
+ context:
+ {context}\n<</SYS>>\n\n
+ question:
+ {user_query}
+ [/INST]""",
+
+         "llama_rag_prompt" : """<s>[INST]\n<<SYS>>\nGiven the following question and context, summarize the parts that are relevant to answer the question. If none of the context is relevant return NO_OUTPUT.\n\n>
+ - Do not mention quotes.\n\n
+ - Reply using a single sentence.\n\n
+ > Context:\n
+ >>>\n{context}\n>>>\n<</SYS>>\n\n
+ Question: {question}\n
+ [/INST]
+ The relevant parts of the context are:
+ """,
+
+         "prompt_template" : """You are a stoic teacher that provides guidance and advice inspired by Stoic philosophy on navigating life's challenges with resilience and inner peace. Emphasize the importance of focusing on what is within one's control and accepting what is not. Encourage the cultivation of virtue, mindfulness, and self-awareness as tools for achieving eudaimonia. Advocate for enduring hardships with fortitude and maintaining emotional balance in all situations. Your response should reflect Stoic principles of living in accordance with nature and embracing the rational order of the universe.
+ You should guide the reader towards a fulfilling life focused on virtue rather than external things because living in accordance with virtue leads to eudaimonia or flourishing.
+ context:
+ {context}
+
+ question:
+ {user_query}
+
+ Answer:
+ """,
+         "rag_prompt" : """Given the following question and context, summarize the parts that are relevant to answer the question. If none of the context is relevant return NO_OUTPUT.\n\n>
+ - Do not mention quotes.\n\n>
+ - Reply using a single sentence.\n\n>
+
+ Question: {question}\n> Context:\n>>>\n{context}\n>>>\nRelevant parts"""}
+
+
+     db = load_db(logger, device)
+
+     model, tokenizer = load_model(model_name, logger)
+
+     # streamlit chat
+     if user_question is not None and user_question!="":
+         with st.chat_message("Human", avatar="🧔🏻"):
+             st.write(user_question)
+         response = generate_response(user_question, model, all_templates, logger)
+         with st.chat_message("AI", avatar="🏛️"):
+             st.write(response)