import os
import time
import base64
import logging

import torch
import streamlit as st

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma

from templates import all_templates


@st.cache_resource(show_spinner=False)
def load_model(model_name):
    """Load the selected LLM (and its tokenizer, when applicable)."""
    logger.info("Loading model ..")
    start_time = time.time()

    if model_name == 'llama':
        from langchain.llms import CTransformers

        model = CTransformers(model="TheBloke/Llama-2-7B-Chat-GGML",
                              model_file='llama-2-7b-chat.ggmlv3.q4_0.bin',
                              model_type='llama', gpu_layers=0)
        tokenizer = None

    elif model_name == 'mistral':
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

        model_id = "filipealmeida/Mistral-7B-Instruct-v0.1-sharded"

        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16)

        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True,
                                                     quantization_config=quant_config,
                                                     device_map="auto")

        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Model Loading Time : {time.time() - start_time}.")
    logger.info(f"Model Loading Time : {time.time() - start_time}.")

    return model, tokenizer

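# The CPU path loads a 4-bit GGML build of Llama-2-7B-Chat through ctransformers,
# while the GPU path loads the sharded Mistral-7B-Instruct in 4-bit NF4 via
# bitsandbytes, so either branch should fit on fairly modest hardware.
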
@st.cache_resource(show_spinner=False)
def load_db(device, local_embed=False, CHROMA_PATH='./ChromaDB'):
    """
    Load vector embeddings and the Chroma database.
    """
    encode_kwargs = {'normalize_embeddings': True}
    embed_id = "BAAI/bge-large-en-v1.5"
    start_time = time.time()

    if local_embed:
        # Path to a locally saved copy of the embedding model (left empty here).
        PATH_TO_EMBEDDING_FOLDER = ""

        embeddings = HuggingFaceBgeEmbeddings(model_name=PATH_TO_EMBEDDING_FOLDER,
                                              model_kwargs={"trust_remote_code": True})
        logger.info('Loading embeddings locally.')

    else:
        embeddings = HuggingFaceBgeEmbeddings(model_name=embed_id,
                                              model_kwargs={"device": device},
                                              encode_kwargs=encode_kwargs)
        logger.info('Loading embeddings from Hub.')

    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
    logger.info(f"Vector Embeddings and Chroma Database Loading Time : {time.time() - start_time}.")
    print(f"Vector Embeddings and Chroma Database Loading Time : {time.time() - start_time}.")
    return db

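# The Chroma store at CHROMA_PATH is expected to have been built beforehand with
# the same BGE embedding model; querying it with a different embedding model
# would make the similarity scores meaningless.
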
def wrap_model(model, tokenizer):
    """Wrap a transformers text-generation pipeline with HuggingFacePipeline."""
    from transformers import pipeline

    text_generation_pipeline = pipeline(
        model=model,
        tokenizer=tokenizer,
        task="text-generation",
        temperature=0.5,
        repetition_penalty=2.1,
        no_repeat_ngram_size=3,
        max_new_tokens=400,
        num_beams=2,
        pad_token_id=2,
        do_sample=True)
    HF_pipeline = HuggingFacePipeline(pipeline=text_generation_pipeline)
    return HF_pipeline

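# Usage sketch (assumes the Mistral model and tokenizer returned by load_model):
#   hf_llm = wrap_model(model, tokenizer)
#   print(hf_llm("Summarise Stoic ethics in one sentence."))
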
def fetch_context(db, model, model_name, query, template, use_compressor=True):
    """
    Perform a similarity search and retrieve context related to the query.
    The documents stored in the db are large, so a compressor is applied to the
    set of retrieved documents to make sure the returned, compressed context is
    relevant to the query.
    """
    if use_compressor:
        start_time = time.time()

        if model_name == 'llama':
            compressor = LLMChainExtractor.from_llm(model)
            compressor.llm_chain.prompt.template = template['llama_rag_template']

        elif model_name == 'mistral':
            global HF_pipeline_model
            # Relies on the module-level tokenizer loaded alongside the Mistral model.
            HF_pipeline_model = wrap_model(model, tokenizer)
            compressor = LLMChainExtractor.from_llm(HF_pipeline_model)
            compressor.llm_chain.prompt.template = template['rag_template']

        retriever = db.as_retriever(search_type="mmr")
        compression_retriever = ContextualCompressionRetriever(base_compressor=compressor,
                                                               base_retriever=retriever)

        compressed_docs = compression_retriever.get_relevant_documents(query)

        print(f"Compressed context Generation Time: {time.time() - start_time}")
        return compressed_docs

    docs = db.max_marginal_relevance_search(query)
    return docs

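# Quick-check sketch (hypothetical query; use_compressor=False skips the LLM
# extractor and returns the raw MMR hits from Chroma):
#   docs = fetch_context(db, model, model_name, "What do the Stoics say about fear?",
#                        all_templates, use_compressor=False)
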
def format_context(docs):
    """
    Clean and format retrieved chunks into a single context string.
    """
    cleaned_docs = [doc for doc in docs if ">>>" not in doc.page_content]
    return "\n\n".join(doc.page_content for doc in cleaned_docs)

def llm_chain_with_context(model, model_name, query, context, template):
    """
    Run a simple chain whose prompt combines the query with the retrieved
    context, using the underlying model to generate a response.
    """
    formatted_context = format_context(context)

    start_chain_time = time.time()

    if model_name == 'llama':
        prompt_template = PromptTemplate(input_variables=['context', 'user_query'],
                                         template=template['llama_prompt_template'])
        llm_chain = LLMChain(llm=model, prompt=prompt_template)

    elif model_name == 'mistral':
        prompt_template = PromptTemplate(input_variables=['context', 'user_query'],
                                         template=template['prompt_template'])
        # HF_pipeline_model is the wrapped pipeline created in fetch_context().
        llm_chain = LLMChain(llm=HF_pipeline_model, prompt=prompt_template)

    print(f"LLMChain Setup Time: {time.time() - start_chain_time}")

    start_inference_time = time.time()
    output = llm_chain.predict(user_query=query, context=formatted_context)
    print(f"LLM Inference Time: {time.time() - start_inference_time}")

    return output

def generate_response(query, model, template):
    """Run the full RAG pipeline (uses the module-level db and model_name)."""
    start_time = time.time()
    progress_text = "Running Inference. Please wait."
    my_bar = st.progress(0, text=progress_text)

    my_bar.progress(0.1, "Loading Model. Please wait.")
    time.sleep(2)

    my_bar.progress(0.4, "Running RAG. Please wait.")
    context = fetch_context(db, model, model_name, query, template)

    my_bar.progress(0.6, "Generating Answer. Please wait.")
    response = llm_chain_with_context(model, model_name, query, context, template)

    print(f"Total Execution Time: {time.time() - start_time}")
    logger.info(f"Total Execution Time: {time.time() - start_time}")

    my_bar.progress(0.9, "Post Processing. Please wait.")
    response = post_process(response)

    my_bar.progress(1.0, "Done")
    time.sleep(1)
    my_bar.empty()
    return response

def stream_to_screen(response):
    """Yield the response word by word, for a streaming-style display."""
    for word in response.split():
        yield word + " "
        time.sleep(0.15)

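# Usage sketch (assumes a Streamlit release that provides st.write_stream; the
# generator is not wired into the chat output below):
#   with st.chat_message("AI", avatar="🏛️"):
#       st.write_stream(stream_to_screen(response))
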
def post_process(response):
    """Drop the last sentence if generation stopped mid-sentence."""
    if response and response[-1] != '.':
        sentences = response.split('.')
        # Drop the unfinished trailing fragment.
        del sentences[-1]
        # Also drop a leftover piece with no alphabetic content (e.g. a stray number).
        if sentences and not any(ch.isalpha() for ch in sentences[-1]):
            del sentences[-1]
        if sentences:
            return '.'.join(sentences) + '.'
    return response

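# Example of the intended behaviour:
#   post_process("Virtue is the only good. It is suffici")  ->  "Virtue is the only good."
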
def convert_to_base64(bin_file):
    """Read a binary file and return its contents as a base64 string."""
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()

def set_as_background_img(png_file):
    """Set the given image as the Streamlit app background."""
    bin_str = convert_to_base64(png_file)
    background_img = '''
    <link href='https://fonts.googleapis.com/css?family=Libre Baskerville' rel='stylesheet'>
    <style>
    .stApp {
        background-image: url("data:image/png;base64,%s");
        background-size: cover;
        background-repeat: no-repeat;
        background-attachment: scroll;
    }
    </style>
    ''' % bin_str
    st.markdown(background_img, unsafe_allow_html=True)

if __name__ == "__main__":

    st.set_page_config(page_title='StoicCyber', page_icon="🏛️", layout="centered",
                       initial_sidebar_state="collapsed")
    set_as_background_img('pxfuel.jpg')

    original_title = '<h1 style="font-family: Libre Baskerville; color:#faf8f8; font-size: 30px; text-align: left;">STOIC Ω CYBER</h1>'
    st.markdown(original_title, unsafe_allow_html=True)

    # Hide the default Streamlit header and footer.
    hide_st_style = """
        <style>
        header {visibility: hidden;}
        footer {visibility: hidden;}
        </style>
        """
    st.markdown(hide_st_style, unsafe_allow_html=True)

    logger = logging.getLogger(__name__)
    logging.basicConfig(
        filename="app.log",
        filemode="a",
        format="%(asctime)s.%(msecs)03d %(levelname)s [%(funcName)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",)

    # Pick the model that fits the available hardware: the GGML Llama 2 build
    # runs on CPU, while the 4-bit Mistral build needs a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "llama" if device == "cpu" else "mistral"
    logger.info(f"Running {model_name} model for inference on {device}")
    print(f"Running {model_name} model for inference on {device}")

    db = load_db(device)
    model, tokenizer = load_model(model_name)

    user_question = st.chat_input('What do you want to ask ..')

    if user_question is not None and user_question != "":
        with st.chat_message("Human", avatar="🧔🏻"):
            st.write(user_question)

        response = generate_response(user_question, model, all_templates)

        with st.chat_message("AI", avatar="🏛️"):
            st.write(response)