import time

import streamlit as st

from chat_client import chat
from utils import gen_augmented_prompt_via_websearch, inital_prompt_engineering_dict

COST_PER_1000_TOKENS_USD = 0.139 * 80

CHAT_BOTS = {
    "Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Mistral 7B v0.1": "mistralai/Mistral-7B-Instruct-v0.1",
}

st.set_page_config(
    page_title="Mixtral Playground",
    page_icon="📚",
)


def init_state():
    """Seed st.session_state with defaults on the first run."""
    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "tokens_used" not in st.session_state:
        st.session_state.tokens_used = 0

    if "tps" not in st.session_state:
        st.session_state.tps = 0

    if "temp" not in st.session_state:
        st.session_state.temp = 0.8

    if "history" not in st.session_state:
        st.session_state.history = [
            [
                inital_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
                inital_prompt_engineering_dict["SYSTEM_RESPONSE"],
            ]
        ]

    if "n_crawl" not in st.session_state:
        st.session_state.n_crawl = 5

    if "repetition_penalty" not in st.session_state:
        st.session_state.repetition_penalty = 1

    if "rag_enabled" not in st.session_state:
        st.session_state.rag_enabled = True

    if "chat_bot" not in st.session_state:
        st.session_state.chat_bot = "Mixtral 8x7B v0.1"

    if "search_vendor" not in st.session_state:
        st.session_state.search_vendor = "Bing"

    if "system_instruction" not in st.session_state:
        st.session_state.system_instruction = inital_prompt_engineering_dict[
            "SYSTEM_INSTRUCTION"
        ]

    if "system_response" not in st.session_state:
        st.session_state.system_response = inital_prompt_engineering_dict[
            "SYSTEM_RESPONSE"
        ]

    if "pre_context" not in st.session_state:
        st.session_state.pre_context = inital_prompt_engineering_dict["PRE_CONTEXT"]

    if "post_context" not in st.session_state:
        st.session_state.post_context = inital_prompt_engineering_dict["POST_CONTEXT"]

    if "pre_prompt" not in st.session_state:
        st.session_state.pre_prompt = inital_prompt_engineering_dict["PRE_PROMPT"]

    if "post_prompt" not in st.session_state:
        st.session_state.post_prompt = inital_prompt_engineering_dict["POST_PROMPT"]

    if "pass_prev" not in st.session_state:
        st.session_state.pass_prev = False

    if "chunk_size" not in st.session_state:
        st.session_state.chunk_size = 512
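# Note: the repeated membership checks in init_state could be collapsed into a
# defaults table; a minimal sketch (keys and values taken from init_state above):
#
#     DEFAULTS = {"messages": [], "tokens_used": 0, "tps": 0, "temp": 0.8}
#     for key, value in DEFAULTS.items():
#         if key not in st.session_state:
#             st.session_state[key] = value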
def sidebar():
    def retrieval_settings():
        st.markdown("# Web Retrieval")
        st.session_state.rag_enabled = st.toggle("Activate Web Retrieval", value=True)
        st.session_state.search_vendor = st.radio(
            "Select Search Vendor",
            ["Bing", "Google"],
            disabled=not st.session_state.rag_enabled,
        )
        st.session_state.n_crawl = st.slider(
            label="Links to Crawl",
            key=1,
            min_value=1,
            max_value=10,
            value=4,
            disabled=not st.session_state.rag_enabled,
        )
        st.session_state.top_k = st.slider(
            label="Chunks to Retrieve via Reranker",
            key=2,
            min_value=1,
            max_value=20,
            value=5,
            disabled=not st.session_state.rag_enabled,
        )
        st.session_state.chunk_size = st.slider(
            label="Chunk Size",
            value=512,
            min_value=128,
            max_value=1024,
            step=8,
            disabled=not st.session_state.rag_enabled,
        )
        st.markdown("---")

    def model_analytics():
        st.markdown("# Model Analytics")
        st.write("Total tokens used :", st.session_state["tokens_used"])
        st.write("Speed :", st.session_state["tps"], " tokens/sec")
        st.write(
            "Total cost incurred :",
            round(
                COST_PER_1000_TOKENS_USD * st.session_state["tokens_used"] / 1000,
                3,
            ),
            "USD",
        )
        st.markdown("---")

    def model_settings():
        st.markdown("# Model Settings")
        st.session_state.chat_bot = st.radio("Select one:", list(CHAT_BOTS.keys()))
        st.session_state.temp = st.slider(
            label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
        )
        st.session_state.max_tokens = st.slider(
            label="New tokens to generate",
            min_value=64,
            max_value=2048,
            step=32,
            value=512,
        )
        st.session_state.repetition_penalty = st.slider(
            label="Repetition Penalty",
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            value=1.0,
        )

    with st.sidebar:
        retrieval_settings()
        model_analytics()
        model_settings()
        st.markdown(
            """
            > **Created by [Pragnesh Barik](https://barik.super.site) 🔗**
            """
        )


def prompt_engineering_dashboard():
    def engineer_prompt():
        st.session_state.history[0] = [
            st.session_state.system_instruction,
            st.session_state.system_response,
        ]

    with st.expander("Prompt Engineering Dashboard"):
        st.info("**The model input is assembled from the template below**")
        st.code(
            """
            [SYSTEM INSTRUCTION]
            [SYSTEM RESPONSE]

            [... LIST OF PREV INPUTS]

            [PRE CONTEXT]
            [CONTEXT RETRIEVED FROM THE WEB]
            [POST CONTEXT]

            [PRE PROMPT]
            [PROMPT]
            [POST PROMPT]
            [PREV GENERATED OUTPUT]  # Only if "Pass previous output" is enabled
            """
        )
        st.session_state.system_instruction = st.text_area(
            label="SYSTEM INSTRUCTION",
            value=inital_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
        )
        st.session_state.system_response = st.text_area(
            "SYSTEM RESPONSE", value=inital_prompt_engineering_dict["SYSTEM_RESPONSE"]
        )
        col1, col2 = st.columns(2)
        with col1:
            st.session_state.pre_context = st.text_input(
                "PRE CONTEXT",
                value=inital_prompt_engineering_dict["PRE_CONTEXT"],
                disabled=not st.session_state.rag_enabled,
            )
            st.session_state.post_context = st.text_input(
                "POST CONTEXT",
                value=inital_prompt_engineering_dict["POST_CONTEXT"],
                disabled=not st.session_state.rag_enabled,
            )
        with col2:
            st.session_state.pre_prompt = st.text_input(
                "PRE PROMPT", value=inital_prompt_engineering_dict["PRE_PROMPT"]
            )
            st.session_state.post_prompt = st.text_input(
                "POST PROMPT", value=inital_prompt_engineering_dict["POST_PROMPT"]
            )
        col3, col4 = st.columns(2)
        with col3:
            st.session_state.pass_prev = st.toggle("Pass previous output")
        with col4:
            st.button("Engineer Prompts", on_click=engineer_prompt)


def header():
    st.write("# Mixtral Playground")
    prompt_engineering_dashboard()


def chat_box():
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def generate_chat_stream(prompt):
    # 1. Augments the prompt according to the template above.
    # 2. Returns the chat stream and the retrieved source links.
    # 3. The stream and links are consumed by stream_handler and show_source.
    links = []
    if st.session_state.rag_enabled:
        with st.spinner("Fetching relevant documents from the web..."):
            prompt, links = gen_augmented_prompt_via_websearch(
                prompt=prompt,
                pre_context=st.session_state.pre_context,
                post_context=st.session_state.post_context,
                pre_prompt=st.session_state.pre_prompt,
                post_prompt=st.session_state.post_prompt,
                vendor=st.session_state.search_vendor,
                top_k=st.session_state.top_k,
                n_crawl=st.session_state.n_crawl,
                pass_prev=st.session_state.pass_prev,
                prev_output=st.session_state.history[-1][1],
            )

    with st.spinner("Generating response..."):
        chat_stream = chat(
            prompt,
            st.session_state.history,
            chat_client=CHAT_BOTS[st.session_state.chat_bot],
            temperature=st.session_state.temp,
            max_new_tokens=st.session_state.max_tokens,
        )

    return chat_stream, links
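# Note: stream_handler below assumes each streamed chunk exposes `.token.text`
# (the shape of huggingface_hub's text-generation streaming responses). For
# exercising the UI without a live endpoint, a hypothetical stand-in stream
# could look like this (not part of chat_client):
#
#     from types import SimpleNamespace
#
#     def fake_chat_stream(text="Hello from a canned response."):
#         for word in text.split():
#             yield SimpleNamespace(token=SimpleNamespace(text=word + " "))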
def stream_handler(prompt, chat_stream, placeholder):
    # 1. Streams chat_stream into the placeholder as tokens arrive.
    # 2. Returns full_response for token accounting.
    start_time = time.time()
    full_response = ""

    for chunk in chat_stream:
        if chunk.token.text != "":
            full_response += chunk.token.text
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)

    end_time = time.time()
    elapsed_time = end_time - start_time
    total_tokens_processed = len(full_response.split())
    tokens_per_second = total_tokens_processed // elapsed_time

    # Rough token estimate: whitespace word count scaled by 1.25 tokens/word.
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25

    col1, col2, col3 = st.columns(3)
    with col1:
        st.write(f"**{tokens_per_second} tokens/second**")
    with col2:
        st.write(f"**{int(len_response)} tokens generated**")
    with col3:
        st.write(
            f"**$ {round(len_response * COST_PER_1000_TOKENS_USD / 1000, 5)} cost incurred**"
        )

    st.session_state["tps"] = tokens_per_second
    st.session_state["tokens_used"] = len_response + st.session_state["tokens_used"]

    return full_response


def show_source(links):
    # Expander component to list the links the retrieved context came from.
    with st.expander("Show source"):
        for link in links:
            st.info(link)


init_state()
sidebar()
header()
chat_box()

# Main chat loop
if prompt := st.chat_input("Generate Ebook"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat_stream, links = generate_chat_stream(prompt)

    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = stream_handler(prompt, chat_stream, placeholder)
        if st.session_state.rag_enabled:
            show_source(links)

    st.session_state.history.append([prompt, full_response])
    st.session_state.messages.append({"role": "assistant", "content": full_response})
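# Usage (a sketch, assuming this file is saved as app.py and that chat_client.py
# and utils.py sit alongside it):
#
#     streamlit run app.py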