Update app.py
Browse files
app.py
CHANGED
@@ -15,11 +15,9 @@ with st.sidebar:
|
|
15 |
system_prompt_input = st.text_input("Optional system prompt:")
|
16 |
temperature_slider = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
|
17 |
max_new_tokens_slider = st.slider("Max new tokens", min_value=0.0, max_value=4096.0, value=4096.0, step=64.0)
|
18 |
-
topp_slider = st.slider("Top-p (nucleus sampling)", min_value=0.0, max_value=1.0, value=0.6, step=0.05)
|
19 |
-
repetition_penalty_slider = st.slider("Repetition penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.05)
|
20 |
|
21 |
# Prediction function
|
22 |
-
def get_llama2_response(user_message, system_prompt, temperature, max_new_tokens, topp, repetition_penalty):
|
23 |
with st.status("Requesting Llama-2"):
|
24 |
st.write("Requesting API...")
|
25 |
response = llama2_client.predict(
|
@@ -57,9 +55,7 @@ if user_input := st.chat_input("Ask Llama-2-70B anything..."):
|
|
57 |
user_input,
|
58 |
system_prompt_input,
|
59 |
temperature_slider,
|
60 |
-
max_new_tokens_slider
|
61 |
-
topp_slider,
|
62 |
-
repetition_penalty_slider
|
63 |
)
|
64 |
# Display assistant response in chat message container
|
65 |
with st.chat_message("assistant", avatar='🦙'):
|
|
|
15 |
system_prompt_input = st.text_input("Optional system prompt:")
|
16 |
temperature_slider = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
|
17 |
max_new_tokens_slider = st.slider("Max new tokens", min_value=0.0, max_value=4096.0, value=4096.0, step=64.0)
|
|
|
|
|
18 |
|
19 |
# Prediction function
|
20 |
+
def get_llama2_response(user_message, system_prompt, temperature, max_new_tokens, topp=0.6, repetition_penalty=1.2):
|
21 |
with st.status("Requesting Llama-2"):
|
22 |
st.write("Requesting API...")
|
23 |
response = llama2_client.predict(
|
|
|
55 |
user_input,
|
56 |
system_prompt_input,
|
57 |
temperature_slider,
|
58 |
+
max_new_tokens_slider
|
|
|
|
|
59 |
)
|
60 |
# Display assistant response in chat message container
|
61 |
with st.chat_message("assistant", avatar='🦙'):
|