Spaces:
Running
Running
import streamlit as st | |
import yaml | |
def sidebar(session_state, config):
    """Render the app sidebar and store every widget's value on *session_state*.

    Three sections are rendered inside the Streamlit sidebar, in order:
    web-retrieval settings, model analytics (tokens used, speed, cost) and
    model settings. Each widget's current value is assigned as an attribute
    of *session_state* so the rest of the app can read it.

    Args:
        session_state: Object that receives widget values as attributes and
            supplies ``"tokens_used"`` and ``"tps"`` via item access
            (presumably ``st.session_state`` -- confirm against the caller).
        config: Mapping with ``"COST_PER_1000_TOKENS_USD"`` (a number used to
            price token usage) and ``"CHAT_BOTS"`` (a mapping whose keys are
            offered as the selectable chat bots).

    Raises:
        KeyError: For dict-like inputs, if *config* lacks either required key
            or *session_state* has no ``"tokens_used"`` / ``"tps"`` entry.
    """
    COST_PER_1000_TOKENS_USD = config["COST_PER_1000_TOKENS_USD"]
    CHAT_BOTS = config["CHAT_BOTS"]

    def retrieval_settings():
        """Render web-retrieval controls; all but the toggle follow its state."""
        st.markdown("# Web Retrieval")
        session_state.rag_enabled = st.toggle("Activate Web Retrieval", value=True)
        # Every retrieval widget below is greyed out while retrieval is off.
        rag_disabled = not session_state.rag_enabled
        session_state.search_vendor = st.radio(
            "Select Search Vendor",
            ["Bing", "Google"],
            disabled=rag_disabled,
        )
        # The explicit widget keys (1 and 2) are preserved: Streamlit also
        # exposes a keyed widget's value under that key in st.session_state,
        # which other code may rely on.
        session_state.n_crawl = st.slider(
            label="Links to Crawl",
            key=1,
            min_value=1,
            max_value=10,
            value=4,
            disabled=rag_disabled,
        )
        session_state.top_k = st.slider(
            label="Chunks to Retrieve via Reranker",
            key=2,
            min_value=1,
            max_value=20,
            value=5,
            disabled=rag_disabled,
        )
        session_state.chunk_size = st.slider(
            label="Chunk Size",
            value=512,
            min_value=128,
            max_value=1024,
            step=8,
            disabled=rag_disabled,
        )
        st.markdown("---")

    def model_analytics():
        """Show cumulative token usage, generation speed and estimated cost."""
        st.markdown("# Model Analytics")
        tokens_used = session_state["tokens_used"]
        st.write("Total tokens used :", tokens_used)
        st.write("Speed :", session_state["tps"], " tokens/sec")
        # The configured price is per 1000 tokens; round to 3 decimals for display.
        st.write(
            "Total cost incurred :",
            round(COST_PER_1000_TOKENS_USD * tokens_used / 1000, 3),
            "USD",
        )
        st.markdown("---")

    def model_settings():
        """Render the chat-bot picker and generation-parameter sliders."""
        st.markdown("# Model Settings")
        # This helper only runs inside the `with st.sidebar:` context below, so
        # plain `st.radio` is placed in the sidebar like every other widget.
        session_state.chat_bot = st.radio("Select one:", list(CHAT_BOTS))
        session_state.temp = st.slider(
            label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
        )
        session_state.max_tokens = st.slider(
            label="New tokens to generate",
            min_value=64,
            max_value=2048,
            step=32,
            value=512,
        )
        # NOTE(review): many text-generation backends treat 1.0 as "no penalty"
        # and values > 1.0 as penalizing repetition, so a 0.0-1.0 range can only
        # keep or weaken the penalty, and 0.0 may be rejected outright. Confirm
        # the intended range against the code that consumes this value.
        # The attribute keeps its original spelling (`repetion_penalty`) since
        # it is presumably read under that name elsewhere.
        session_state.repetion_penalty = st.slider(
            label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
        )

    with st.sidebar:
        retrieval_settings()
        model_analytics()
        model_settings()