# Spaces: Running
import datetime | |
import os | |
import re | |
from io import StringIO | |
import gradio as gr | |
import pandas as pd | |
from huggingface_hub import upload_file | |
from text_generation import Client | |
from dialogues import DialogueTemplate | |
from share_btn import (community_icon_html, loading_icon_html, share_btn_css, | |
share_js) | |
# Maps each selectable model name to its inference endpoint URL.
# The URL is read from the API_URL environment variable and is None when unset.
model2endpoint = {"starchat-beta": os.environ.get("API_URL")}

# Model names in insertion order; used to populate the model selector.
model_names = list(model2endpoint)
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Assemble the full text prompt from the preprompt, past turns and new input.

    Args:
        inputs: The new user message.
        chatbot: Iterable of ``(user_text, model_text)`` pairs for past turns.
        preprompt: Text placed at the very start of the prompt.
        user_name: Prefix marking a user turn; added when a user text lacks it.
        assistant_name: Prefix marking an assistant turn.
        sep: Separator placed between turns.

    Returns:
        ``preprompt`` + every past turn + the new user turn + ``sep`` + the
        right-stripped ``assistant_name``, as a single string.
    """
    assistant_prefix = sep + assistant_name

    def _with_prefix(text, prefix):
        # Only prepend the marker when the text does not already carry it.
        return text if text.startswith(prefix) else prefix + text

    turns = [
        _with_prefix(user_text, user_name)
        + _with_prefix(model_text, assistant_prefix).rstrip()
        + sep
        for user_text, model_text in chatbot
    ]
    pieces = [
        preprompt,
        *turns,
        _with_prefix(inputs, user_name),
        sep,
        assistant_name.rstrip(),
    ]
    return "".join(pieces)
def wrap_html_code(text):
    """Return ``text`` fenced in triple backticks when it contains a ``<...>`` tag.

    A "tag" is any non-greedy match of ``<.*?>`` on a single line. Text
    without such a match is returned unchanged.
    """
    contains_tag = re.search(r"<.*?>", text) is not None
    return f"```{text}```" if contains_tag else text
def has_no_history(chatbot, history):
    """Return True only when both ``chatbot`` and ``history`` are empty/falsy."""
    return not (chatbot or history)
def _history_to_chat(history):
    """Pair the flat ``history`` list into ``(user, assistant)`` display tuples.

    Entries are taken two at a time; a trailing unpaired entry is ignored.
    Each message is stripped and passed through ``wrap_html_code``.
    """
    return [
        (wrap_html_code(history[i].strip()), wrap_html_code(history[i + 1].strip()))
        for i in range(0, len(history) - 1, 2)
    ]


def generate(
    model_name,
    system_message,
    user_message,
    chatbot,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
):
    """Stream the model's reply to ``user_message`` and update the chat state.

    This is a generator: after every non-special streamed token it yields
    ``(chat, history, user_message, "")`` where ``chat`` is the list of
    ``(user, assistant)`` display pairs rebuilt from ``history``.

    Args:
        model_name: Key into ``model2endpoint`` selecting the endpoint.
        system_message: System prompt passed to ``DialogueTemplate``.
        user_message: The new user message. When empty, no request is made.
        chatbot: Past ``(user, assistant)`` pairs used to build the prompt.
        history: Flat list of alternating user/assistant messages. It is
            mutated in place: the user message is appended, then the reply.
        temperature: Sampling temperature; clamped to at least ``1e-2``.
        top_k: Top-k shortlist size. Non-positive values are sent as ``None``
            (presumably the endpoint requires a strictly positive value;
            confirm against the text_generation client docs).
        top_p: Nucleus sampling threshold, converted to ``float``.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Repetition penalty forwarded to the endpoint.

    Yields:
        ``(chat, history, user_message, "")`` tuples, one per streamed token,
        or a single tuple with the unchanged ``chatbot`` for empty input.
    """
    # Don't query the model (and don't touch the history) when the input is empty.
    if not user_message:
        print("Empty input")
        yield chatbot, history, user_message, ""
        return chatbot, history, user_message, ""

    client = Client(model2endpoint[model_name])

    history.append(user_message)

    # Rebuild the dialogue from the displayed pairs, then add the new user turn.
    messages = []
    for user_data, model_data in chatbot:
        messages.append({"role": "user", "content": user_data})
        messages.append({"role": "assistant", "content": model_data.rstrip()})
    messages.append({"role": "user", "content": user_message})

    dialogue_template = DialogueTemplate(system=system_message, messages=messages)
    prompt = dialogue_template.get_inference_prompt()

    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    # Forward top_k so the UI slider actually takes effect; 0 means "disabled".
    top_k = int(top_k) if top_k and top_k > 0 else None

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=1000,
        seed=42,
        stop_sequences=["<|end|>"],
    )

    stream = client.generate_stream(
        prompt,
        **generate_kwargs,
    )

    output = ""
    # Track the reply slot explicitly. Relying on the stream index breaks when
    # the first token is special: the next token would then overwrite the
    # user message stored in history[-1].
    reply_appended = False
    chat = _history_to_chat(history)
    for response in stream:
        if response.token.special:
            continue
        output += response.token.text
        if reply_appended:
            history[-1] = output
        else:
            # Leading space kept for the first token; it is stripped for display.
            history.append(" " + output)
            reply_appended = True

        chat = _history_to_chat(history)
        yield chat, history, user_message, ""

    if not reply_appended:
        # No visible token arrived: drop the unanswered user message so the
        # user/assistant pairing in history stays aligned for later turns.
        history.pop()

    return chat, history, user_message, ""
# Example prompts offered in the UI; each one fills the user message textbox.
examples = [
    "How can I write a Python function to generate the nth Fibonacci number?",
    "How do I get the current date using shell commands? Explain how it works.",
    "What's the meaning of life?",
    "Write a function in Javascript to reverse words in a given string.",
    (
        "Give the following data {'Name':['Tom', 'Brad', 'Kyle', 'Jerry'], "
        "'Age':[20, 21, 19, 18], 'Height' : [6.1, 5.9, 6.0, 6.1]}. "
        "Can you plot one graph with two subplots as columns. "
        "The first is a bar graph showing the height of each person. "
        "The second is a bargraph showing the age of each person? "
        "Draw the graph in seaborn talk mode."
    ),
    "Create a regex to extract dates from logs",
    "How to decode JSON into a typescript object",
    "Write a list into a jsonlines file and save locally",
]
def clear_chat():
    """Return two fresh empty lists: the reset chatbot pairs and the reset history."""
    empty_chatbot = []
    empty_history = []
    return empty_chatbot, empty_history
def process_example(args):
    """Run ``generate`` to exhaustion and return its last item as ``[x, y]``.

    NOTE(review): ``generate`` is called with a single positional argument and
    each yielded item is unpacked into exactly two values, which does not look
    consistent with ``generate``'s signature and 4-tuple yields in this file.
    Confirm the intended contract before relying on this function.
    """
    for item in generate(args):
        # Each item must unpack into exactly two values.
        x, y = item
    return [x, y]
# HTML heading rendered at the top of the page.
title = '<h1 align="center">⭐ StarChat Playground 💬</h1>'

# Extra CSS handed to gr.Blocks: centres the banner image and sizes the chat area.
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
# Build the chat UI, wire its events, then serve it.
# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation was lost — confirm the layout against the running app.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        selected_model = gr.Radio(choices=model_names, value=model_names[0], label="Select a model")

    with gr.Accordion(label="System Prompt", open=False, elem_id="parameters-accordion"):
        system_message = gr.Textbox(
            elem_id="system-message",
            placeholder="Below is a conversation between a human user and a helpful AI coding assistant.",
            show_label=False,
        )

    with gr.Row():
        with gr.Box():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")

            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)
                # regenerate_button = gr.Button("Regenerate", elem_id="send-btn", visible=True)
                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)

            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_k = gr.Slider(
                    label="Top-k",
                    value=50,
                    minimum=0.0,
                    maximum=100,
                    step=1,
                    interactive=True,
                    info="Sample from a shortlist of top-k tokens",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=1024,
                    minimum=0,
                    maximum=2048,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )

            # with gr.Group(elem_id="share-btn-container"):
            #     community_icon = gr.HTML(community_icon_html, visible=True)
            #     loading_icon = gr.HTML(loading_icon_html, visible=True)
            #     share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)

            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[user_message],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )

    history = gr.State([])
    # To clear out "message" input textbox and use this to regenerate message
    last_user_message = gr.State("")

    # Pressing Enter in the textbox and clicking "Send" trigger the same
    # generation call; fresh input/output lists are built for each listener.
    for register_generate in (user_message.submit, send_button.click):
        register_generate(
            generate,
            inputs=[
                selected_model,
                system_message,
                user_message,
                chatbot,
                history,
                temperature,
                top_k,
                top_p,
                max_new_tokens,
                repetition_penalty,
            ],
            outputs=[chatbot, history, last_user_message, user_message],
        )

    # Both the "Clear chat" button and a model switch reset the conversation.
    for register_reset in (clear_chat_button.click, selected_model.change):
        register_reset(clear_chat, outputs=[chatbot, history])
    # share_button.click(None, [], [], _js=share_js)

demo.queue(concurrency_count=16).launch()