|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
# Default system prompt; editable in the UI via the "System message" textbox.
SYSTEM_MESSAGE_DEFAULT = "You are a friendly Chatbot."

# Default generation parameters, exposed as sliders in the UI below.
MAX_TOKENS_DEFAULT = 512

TEMPERATURE_DEFAULT = 0.7

TOP_P_DEFAULT = 0.95


# Client for the hosted zephyr-7b-beta model on the Hugging Face Inference API.
inference_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
|
|
|
|
def respond(
    user_message: str,
    conversation_history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Stream an assistant reply to *user_message* given the conversation history.

    Args:
        user_message (str): The user's latest message.
        conversation_history (list[tuple[str, str]]): Prior
            (user, assistant) message pairs.
        system_message (str): System prompt prepended to the conversation.
        max_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        top_p (float): Nucleus-sampling probability mass.

    Yields:
        list[tuple[str, str]]: The conversation history extended with the
        partially generated assistant response (one yield per streamed chunk).
    """
    # Rebuild the OpenAI-style message list from the (user, assistant) pairs.
    messages = [{"role": "system", "content": system_message}]

    for user_input, assistant_response in conversation_history:
        # Skip empty slots so we never send blank messages to the model.
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if assistant_response:
            messages.append({"role": "assistant", "content": assistant_response})

    messages.append({"role": "user", "content": user_message})

    response = ""

    for chunk in inference_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # Streamed deltas may carry ``content=None`` (e.g. role-only or final
        # chunks); unguarded ``response += token`` would raise TypeError.
        if token:
            response += token

        yield conversation_history + [(user_message, response)]
|
|
|
|
|
|
|
# Build the extra controls first so the ChatInterface call stays readable.
system_message_input = gr.Textbox(
    value=SYSTEM_MESSAGE_DEFAULT,
    label="System message",
)

max_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2048,
    value=MAX_TOKENS_DEFAULT,
    step=1,
    label="Max new tokens",
)

temperature_slider = gr.Slider(
    minimum=0.1,
    maximum=4.0,
    value=TEMPERATURE_DEFAULT,
    step=0.1,
    label="Temperature",
)

top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=TOP_P_DEFAULT,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Chat UI wired to ``respond``; the controls above are forwarded to it as
# extra arguments after (message, history).
chatbot_interface = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(height=600),
    additional_inputs=[
        system_message_input,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ],
)
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app only when executed as a script, not on import.
    chatbot_interface.launch()
|
|