import time

import streamlit as st

from chat_client import chat
from utils import gen_augmented_prompt_via_websearch, inital_prompt_engineering_dict
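# Assumed flat price per 1000 tokens in USD; used only for the rough cost
# estimates shown in the sidebar and after each response.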
COST_PER_1000_TOKENS_USD = 0.139 * 80
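# Display names mapped to the Hugging Face model IDs that chat_client expects.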
CHAT_BOTS = {
    "Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Mistral 7B v0.1": "mistralai/Mistral-7B-Instruct-v0.1",
}
st.set_page_config(
    page_title="Mixtral Playground",
    page_icon="📚",
)
def init_state():
    """Seed every session-state key exactly once, so user settings survive reruns."""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "tokens_used" not in st.session_state:
        st.session_state.tokens_used = 0
    if "tps" not in st.session_state:
        st.session_state.tps = 0
    if "temp" not in st.session_state:
        st.session_state.temp = 0.8
    if "history" not in st.session_state:
        st.session_state.history = [
            [
                inital_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
                inital_prompt_engineering_dict["SYSTEM_RESPONSE"],
            ]
        ]
    if "n_crawl" not in st.session_state:
        st.session_state.n_crawl = 5
    if "repetition_penalty" not in st.session_state:
        st.session_state.repetition_penalty = 1
    if "rag_enabled" not in st.session_state:
        st.session_state.rag_enabled = True
    if "chat_bot" not in st.session_state:
        st.session_state.chat_bot = "Mixtral 8x7B v0.1"
    if "search_vendor" not in st.session_state:
        st.session_state.search_vendor = "Bing"
    if "system_instruction" not in st.session_state:
        st.session_state.system_instruction = inital_prompt_engineering_dict[
            "SYSTEM_INSTRUCTION"
        ]
    if "system_response" not in st.session_state:
        st.session_state.system_response = inital_prompt_engineering_dict[
            "SYSTEM_RESPONSE"
        ]
    if "pre_context" not in st.session_state:
        st.session_state.pre_context = inital_prompt_engineering_dict["PRE_CONTEXT"]
    if "post_context" not in st.session_state:
        st.session_state.post_context = inital_prompt_engineering_dict["POST_CONTEXT"]
    if "pre_prompt" not in st.session_state:
        st.session_state.pre_prompt = inital_prompt_engineering_dict["PRE_PROMPT"]
    if "post_prompt" not in st.session_state:
        st.session_state.post_prompt = inital_prompt_engineering_dict["POST_PROMPT"]
    if "pass_prev" not in st.session_state:
        st.session_state.pass_prev = False
    if "chunk_size" not in st.session_state:
        st.session_state.chunk_size = 512
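# Streamlit re-executes this whole script on every interaction; st.session_state
# is the only store that survives a rerun, which is why each setting above is
# seeded exactly once.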
def sidebar():
    def retrieval_settings():
        st.markdown("# Web Retrieval")
        st.session_state.rag_enabled = st.toggle("Activate Web Retrieval", value=True)
        st.session_state.search_vendor = st.radio(
            "Select Search Vendor",
            ["Bing", "Google"],
            disabled=not st.session_state.rag_enabled,
        )
        st.session_state.n_crawl = st.slider(
            label="Links to Crawl",
            key=1,
            min_value=1,
            max_value=10,
            value=4,
            disabled=not st.session_state.rag_enabled,
        )
        st.session_state.top_k = st.slider(
            label="Chunks to Retrieve via Reranker",
            key=2,
            min_value=1,
            max_value=20,
            value=5,
            disabled=not st.session_state.rag_enabled,
        )
        st.session_state.chunk_size = st.slider(
            label="Chunk Size",
            value=512,
            min_value=128,
            max_value=1024,
            step=8,
            disabled=not st.session_state.rag_enabled,
        )
        st.markdown("---")

    def model_analytics():
        st.markdown("# Model Analytics")
        st.write("Total tokens used :", st.session_state["tokens_used"])
        st.write("Speed :", st.session_state["tps"], " tokens/sec")
        st.write(
            "Total cost incurred :",
            round(
                COST_PER_1000_TOKENS_USD * st.session_state["tokens_used"] / 1000,
                3,
            ),
            "USD",
        )
        st.markdown("---")

    def model_settings():
        st.markdown("# Model Settings")
        st.session_state.chat_bot = st.radio("Select one:", list(CHAT_BOTS))
        st.session_state.temp = st.slider(
            label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
        )
        st.session_state.max_tokens = st.slider(
            label="New tokens to generate",
            min_value=64,
            max_value=2048,
            step=32,
            value=512,
        )
        st.session_state.repetition_penalty = st.slider(
            label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
        )

    with st.sidebar:
        retrieval_settings()
        model_analytics()
        model_settings()
        st.markdown(
            """
            > **Created by [Pragnesh Barik](https://barik.super.site) 🔗**
            """
        )
def prompt_engineering_dashboard():
    def engineer_prompt():
        # Reseed the conversation's first turn with the edited system prompt pair.
        st.session_state.history[0] = [
            st.session_state.system_instruction,
            st.session_state.system_response,
        ]

    with st.expander("Prompt Engineering Dashboard"):
        st.info("**The input to the model follows the template below**")
        st.code(
            """
            [SYSTEM INSTRUCTION]
            [SYSTEM RESPONSE]
            [... LIST OF PREV INPUTS]
            [PRE CONTEXT]
            [CONTEXT RETRIEVED FROM THE WEB]
            [POST CONTEXT]
            [PRE PROMPT]
            [PROMPT]
            [POST PROMPT]
            [PREV GENERATED OUTPUT]  # Included only when "Pass previous Output" is on
            """
        )
        st.session_state.system_instruction = st.text_area(
            label="SYSTEM INSTRUCTION",
            value=inital_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
        )
        st.session_state.system_response = st.text_area(
            "SYSTEM RESPONSE", value=inital_prompt_engineering_dict["SYSTEM_RESPONSE"]
        )
        col1, col2 = st.columns(2)
        with col1:
            st.session_state.pre_context = st.text_input(
                "PRE CONTEXT",
                value=inital_prompt_engineering_dict["PRE_CONTEXT"],
                disabled=not st.session_state.rag_enabled,
            )
            st.session_state.post_context = st.text_input(
                "POST CONTEXT",
                value=inital_prompt_engineering_dict["POST_CONTEXT"],
                disabled=not st.session_state.rag_enabled,
            )
        with col2:
            st.session_state.pre_prompt = st.text_input(
                "PRE PROMPT", value=inital_prompt_engineering_dict["PRE_PROMPT"]
            )
            st.session_state.post_prompt = st.text_input(
                "POST PROMPT", value=inital_prompt_engineering_dict["POST_PROMPT"]
            )
        col3, col4 = st.columns(2)
        with col3:
            st.session_state.pass_prev = st.toggle("Pass previous Output")
        with col4:
            st.button("Engineer Prompts", on_click=engineer_prompt)
def header():
    st.write("# Mixtral Playground")
    prompt_engineering_dashboard()


def chat_box():
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
def generate_chat_stream(prompt):
    """Augment the prompt according to the template (when RAG is on) and return
    the chat stream plus the source links, which stream_handler and show_source
    then consume."""
    links = []
    if st.session_state.rag_enabled:
        with st.spinner("Fetching relevant documents from the web..."):
            prompt, links = gen_augmented_prompt_via_websearch(
                prompt=prompt,
                pre_context=st.session_state.pre_context,
                post_context=st.session_state.post_context,
                pre_prompt=st.session_state.pre_prompt,
                post_prompt=st.session_state.post_prompt,
                vendor=st.session_state.search_vendor,
                top_k=st.session_state.top_k,
                n_crawl=st.session_state.n_crawl,
                pass_prev=st.session_state.pass_prev,
                prev_output=st.session_state.history[-1][1],
            )

    with st.spinner("Generating response..."):
        chat_stream = chat(
            prompt,
            st.session_state.history,
            chat_client=CHAT_BOTS[st.session_state.chat_bot],
            temperature=st.session_state.temp,
            max_new_tokens=st.session_state.max_tokens,
        )

    return chat_stream, links
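# gen_augmented_prompt_via_websearch lives in utils and is not shown here;
# judging from the template in the dashboard, it presumably wraps the retrieved
# web chunks in PRE/POST CONTEXT and the user prompt in PRE/POST PROMPT before
# returning (augmented_prompt, source_links).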
def stream_handler(chat_stream, placeholder, prompt):
    """Stream the response into the placeholder and return the full text for
    token accounting."""
    start_time = time.time()
    full_response = ""
    for chunk in chat_stream:
        # chunk.token.text matches huggingface_hub's streaming text-generation
        # responses, which chat_client presumably wraps; skip the end-of-sequence token.
        if chunk.token.text != "</s>":
            full_response += chunk.token.text
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)

    elapsed_time = time.time() - start_time
    # Whitespace word count serves as a cheap proxy for the number of tokens.
    total_tokens_processed = len(full_response.split())
    tokens_per_second = total_tokens_processed // elapsed_time
    # Rough token estimate for prompt + response: word count * 1.25.
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25

    col1, col2, col3 = st.columns(3)
    with col1:
        st.write(f"**{tokens_per_second} tokens/second**")
    with col2:
        st.write(f"**{int(len_response)} tokens generated**")
    with col3:
        st.write(
            f"**$ {round(len_response * COST_PER_1000_TOKENS_USD / 1000, 5)} cost incurred**"
        )

    st.session_state["tps"] = tokens_per_second
    st.session_state["tokens_used"] += len_response
    return full_response
def show_source(links):
    # Expander listing the web sources used for retrieval.
    with st.expander("Show source"):
        for link in links:
            st.info(link)
init_state()
sidebar()
header()
chat_box()
# Main chat loop
if prompt := st.chat_input("Generate Ebook"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat_stream, links = generate_chat_stream(prompt)

    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder, prompt)
        if st.session_state.rag_enabled:
            show_source(links)

    st.session_state.history.append([prompt, full_response])
    st.session_state.messages.append({"role": "assistant", "content": full_response})
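# To run locally (assuming this file is saved as app.py; the filename is an
# assumption, and the chat_client and utils modules must be importable):
#
#   streamlit run app.py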