# ebook-gen / app.py
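#
# A Streamlit playground for chatting with Mixtral / Mistral instruct models,
# optionally augmenting each prompt with web-retrieved context.
#
# Run locally with:  streamlit run app.py
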
import streamlit as st
from chat_client import chat
import time
from utils import gen_augmented_prompt_via_websearch, inital_prompt_engineering_dict
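
# Approximate cost per 1,000 tokens, in USD.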
COST_PER_1000_TOKENS_USD = 0.139 * 80
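
# Maps display names to Hugging Face Hub model ids.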
CHAT_BOTS = {
"Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"Mistral 7B v0.1": "mistralai/Mistral-7B-Instruct-v0.1",
}
st.set_page_config(
page_title="Mixtral Playground",
page_icon="📚",
)
def init_state():
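    """Seed st.session_state with defaults on the first run so that later
    widgets and handlers can read every key unconditionally."""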
if "messages" not in st.session_state:
st.session_state.messages = []
if "tokens_used" not in st.session_state:
st.session_state.tokens_used = 0
if "tps" not in st.session_state:
st.session_state.tps = 0
if "temp" not in st.session_state:
st.session_state.temp = 0.8
if "history" not in st.session_state:
st.session_state.history = [
[
inital_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
inital_prompt_engineering_dict["SYSTEM_RESPONSE"],
]
]
if "n_crawl" not in st.session_state:
st.session_state.n_crawl = 5
if "repetion_penalty" not in st.session_state:
st.session_state.repetion_penalty = 1
if "rag_enabled" not in st.session_state:
st.session_state.rag_enabled = True
if "chat_bot" not in st.session_state:
st.session_state.chat_bot = "Mixtral 8x7B v0.1"
if "search_vendor" not in st.session_state:
st.session_state.search_vendor = "Bing"
if "system_instruction" not in st.session_state:
st.session_state.system_instruction = inital_prompt_engineering_dict[
"SYSTEM_INSTRUCTION"
]
if "system_response" not in st.session_state:
st.session_state.system_instruction = inital_prompt_engineering_dict[
"SYSTEM_RESPONSE"
]
if "pre_context" not in st.session_state:
st.session_state.pre_context = inital_prompt_engineering_dict["PRE_CONTEXT"]
if "post_context" not in st.session_state:
st.session_state.post_context = inital_prompt_engineering_dict["POST_CONTEXT"]
if "pre_prompt" not in st.session_state:
st.session_state.pre_prompt = inital_prompt_engineering_dict["PRE_PROMPT"]
if "post_prompt" not in st.session_state:
st.session_state.post_prompt = inital_prompt_engineering_dict["POST_PROMPT"]
if "pass_prev" not in st.session_state:
st.session_state.pass_prev = False
if "chunk_size" not in st.session_state:
st.session_state.chunk_size = 512
def sidebar():
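    """Render the sidebar: web-retrieval settings, usage analytics, and model
    settings, each writing its value back into st.session_state."""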
def retrieval_settings():
st.markdown("# Web Retrieval")
st.session_state.rag_enabled = st.toggle("Activate Web Retrieval", value=True)
st.session_state.search_vendor = st.radio(
"Select Search Vendor",
["Bing", "Google"],
disabled=not st.session_state.rag_enabled,
)
st.session_state.n_crawl = st.slider(
label="Links to Crawl",
key=1,
min_value=1,
max_value=10,
value=4,
disabled=not st.session_state.rag_enabled,
)
st.session_state.top_k = st.slider(
label="Chunks to Retrieve via Reranker",
key=2,
min_value=1,
max_value=20,
value=5,
disabled=not st.session_state.rag_enabled,
)
st.session_state.chunk_size = st.slider(
label="Chunk Size",
value=512,
min_value=128,
max_value=1024,
step=8,
disabled=not st.session_state.rag_enabled,
)
st.markdown("---")
def model_analytics():
st.markdown("# Model Analytics")
st.write("Total tokens used :", st.session_state["tokens_used"])
st.write("Speed :", st.session_state["tps"], " tokens/sec")
st.write(
"Total cost incurred :",
round(
COST_PER_1000_TOKENS_USD * st.session_state["tokens_used"] / 1000,
3,
),
"USD",
)
st.markdown("---")
def model_settings():
st.markdown("# Model Settings")
        st.session_state.chat_bot = st.radio(
            "Select one:", list(CHAT_BOTS.keys())
        )
st.session_state.temp = st.slider(
label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
)
st.session_state.max_tokens = st.slider(
label="New tokens to generate",
min_value=64,
max_value=2048,
step=32,
value=512,
)
        st.session_state.repetition_penalty = st.slider(
            label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
)
with st.sidebar:
retrieval_settings()
model_analytics()
model_settings()
st.markdown(
"""
> **Created by [Pragnesh Barik](https://barik.super.site) 🔗**
"""
)
def prompt_engineering_dashboard():
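    """Expander exposing every slot of the prompt template (system instruction
    and response, pre/post context, pre/post prompt) for live editing."""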
def engineer_prompt():
st.session_state.history[0] = [
st.session_state.system_instruction,
st.session_state.system_response,
]
with st.expander("Prompt Engineering Dashboard"):
        st.info(
            "**The input to the model follows the template below**",
        )
        st.code(
            """
            [SYSTEM INSTRUCTION]
            [SYSTEM RESPONSE]
            [... LIST OF PREV INPUTS]
            [PRE CONTEXT]
            [CONTEXT RETRIEVED FROM THE WEB]
            [POST CONTEXT]
            [PRE PROMPT]
            [PROMPT]
            [POST PROMPT]
            [PREV GENERATED OUTPUT]  # Only if "Pass previous output" is enabled
            """
        )
st.session_state.system_instruction = st.text_area(
label="SYSTEM INSTRUCTION",
value=inital_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
)
st.session_state.system_response = st.text_area(
"SYSTEM RESPONSE", value=inital_prompt_engineering_dict["SYSTEM_RESPONSE"]
)
col1, col2 = st.columns(2)
with col1:
st.session_state.pre_context = st.text_input(
"PRE CONTEXT",
value=inital_prompt_engineering_dict["PRE_CONTEXT"],
disabled=not st.session_state.rag_enabled,
)
st.session_state.post_context = st.text_input(
"POST CONTEXT",
value=inital_prompt_engineering_dict["POST_CONTEXT"],
disabled=not st.session_state.rag_enabled,
)
with col2:
st.session_state.pre_prompt = st.text_input(
"PRE PROMPT", value=inital_prompt_engineering_dict["PRE_PROMPT"]
)
st.session_state.post_prompt = st.text_input(
"POST PROMPT", value=inital_prompt_engineering_dict["POST_PROMPT"]
)
col3, col4 = st.columns(2)
with col3:
st.session_state.pass_prev = st.toggle("Pass previous Output")
with col4:
st.button("Engineer Prompts", on_click=engineer_prompt)
def header():
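    """Page title plus the prompt-engineering dashboard."""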
st.write("# Mixtral Playground")
prompt_engineering_dashboard()
def chat_box():
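    """Replay the stored transcript so the chat survives Streamlit reruns."""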
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
def generate_chat_stream(prompt):
    """Augment the prompt according to the template (when web retrieval is
    enabled) and return the chat stream plus the source links consumed by
    stream_handler and show_source."""
links = []
if st.session_state.rag_enabled:
with st.spinner("Fetching relevent documents from Web...."):
prompt, links = gen_augmented_prompt_via_websearch(
prompt=prompt,
pre_context=st.session_state.pre_context,
post_context=st.session_state.post_context,
pre_prompt=st.session_state.pre_prompt,
post_prompt=st.session_state.post_prompt,
vendor=st.session_state.search_vendor,
top_k=st.session_state.top_k,
n_crawl=st.session_state.n_crawl,
pass_prev=st.session_state.pass_prev,
prev_output=st.session_state.history[-1][1],
)
with st.spinner("Generating response..."):
chat_stream = chat(
prompt,
st.session_state.history,
chat_client=CHAT_BOTS[st.session_state.chat_bot],
temperature=st.session_state.temp,
max_new_tokens=st.session_state.max_tokens,
)
return chat_stream, links
def stream_handler(chat_stream, placeholder, prompt):
    """Stream tokens from chat_stream into placeholder, then report speed,
    an estimated token count, and cost; returns the full response text."""
start_time = time.time()
full_response = ""
for chunk in chat_stream:
if chunk.token.text != "</s>":
full_response += chunk.token.text
placeholder.markdown(full_response + "▌")
placeholder.markdown(full_response)
end_time = time.time()
elapsed_time = end_time - start_time
    total_tokens_processed = len(full_response.split())
    tokens_per_second = round(total_tokens_processed / max(elapsed_time, 1e-6))
    # Rough token estimate: word count scaled by 1.25 (roughly 0.75 words/token).
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25
col1, col2, col3 = st.columns(3)
with col1:
st.write(f"**{tokens_per_second} tokens/second**")
with col2:
st.write(f"**{int(len_response)} tokens generated**")
with col3:
st.write(
f"**$ {round(len_response * COST_PER_1000_TOKENS_USD / 1000, 5)} cost incurred**"
)
st.session_state["tps"] = tokens_per_second
st.session_state["tokens_used"] = len_response + st.session_state["tokens_used"]
return full_response
def show_source(links):
    """Expander listing the web sources used to augment the prompt."""
    with st.expander("Show source"):
        for link in links:
            st.info(link)
init_state()
sidebar()
header()
chat_box()
# Main chat loop
if prompt := st.chat_input("Generate Ebook"):
st.chat_message("user").markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
chat_stream, links = generate_chat_stream(prompt)
with st.chat_message("assistant"):
placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder, prompt)
if st.session_state.rag_enabled:
show_source(links)
st.session_state.history.append([prompt, full_response])
st.session_state.messages.append({"role": "assistant", "content": full_response})