# Spaces: Running
import gradio as gr | |
from functools import lru_cache | |
# Cache model loading to optimize performance
@lru_cache(maxsize=None)
def load_hf_model(model_name):
    """Load a hosted Hugging Face model through ``gr.load`` and memoize it.

    Repeated calls with the same ``model_name`` return the object created by
    the first call instead of loading the model again. A call that raises is
    not cached, so a failed load can be retried.

    Args:
        model_name: Hub repository id, e.g. ``"deepseek-ai/DeepSeek-R1"``.

    Returns:
        Whatever ``gr.load`` returns for ``"models/<model_name>"``.
    """
    # NOTE(review): the name carries a "models/" prefix *and* an explicit
    # ``src`` is passed; confirm against the gr.load documentation that the
    # prefix is not kept as part of the repository id when ``src`` is given.
    return gr.load(f"models/{model_name}", src="huggingface")
# Load every selectable model once at import time, keyed by its repository id.
MODELS = {
    repo_id: load_hf_model(repo_id)
    for repo_id in (
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        "deepseek-ai/DeepSeek-R1",
        "deepseek-ai/DeepSeek-R1-Zero",
    )
}
# --- Chatbot function ---
def chatbot(input_text, history, model_choice, system_message, max_new_tokens, temperature, top_p):
    """Send one user message to the selected model and extend the chat history.

    Args:
        input_text: The message typed by the user.
        history: List of ``(user_message, assistant_message)`` tuples, or
            ``None``/empty for a fresh conversation.
        model_choice: Key into ``MODELS`` selecting which model to query.
        system_message: Accepted for interface compatibility; it is currently
            not forwarded to the model.
        max_new_tokens: Generation parameter placed in the request payload.
        temperature: Generation parameter placed in the request payload.
        top_p: Generation parameter placed in the request payload.

    Returns:
        A 3-tuple ``(history, history, "")``: the updated history twice (one
        for the chat display, one for the stored state) and an empty string
        used to clear the message box.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``MODELS``.
    """
    # Work on a copy so the list owned by the caller is never mutated in place.
    history = list(history or [])

    # A blank message has nothing to generate from: leave the conversation
    # untouched instead of spending a model call on it.
    if not input_text or not input_text.strip():
        return history, history, ""

    # Get the selected model component
    model_component = MODELS[model_choice]

    # Create payload for the model
    payload = {
        "inputs": input_text,  # Directly pass the input text
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "return_full_text": False,  # Only return the generated text
        },
    }

    # Run inference using the selected model; any failure is reported to the
    # user as the assistant's reply rather than breaking the UI.
    try:
        try:
            response = model_component(**payload)  # Pass payload as keyword arguments
        except TypeError:
            # Gradio documents loaded Blocks as callables that take their
            # inputs positionally. If the keyword payload is rejected, retry
            # with the raw prompt (generation parameters are then not applied).
            response = model_component(input_text)

        if isinstance(response, str):
            # A loaded text interface may hand back the generated text directly.
            assistant_response = response
        elif isinstance(response, list) and response:
            # Extract the generated text from the response
            assistant_response = response[0].get("generated_text", "No response generated.")
        else:
            assistant_response = "Unexpected model response format."
    except Exception as e:
        assistant_response = f"Error: {e}"

    # Append user and assistant messages to history
    history.append((input_text, assistant_response))
    return history, history, ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek Chatbot") as demo:
    gr.Markdown(
        """
# DeepSeek Chatbot
Created by [ruslanmv.com](https://ruslanmv.com/)
This is a demo of different DeepSeek models. Select a model, type your message, and click "Submit".
You can also adjust optional parameters like system message, max new tokens, temperature, and top-p.
"""
    )

    # Per-session conversation store. Created before the layout so the clear
    # button below can reference it.
    chat_history = gr.State([])

    with gr.Row():
        with gr.Column():
            chatbot_output = gr.Chatbot(label="DeepSeek Chatbot", height=500)
            msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                # Include the stored history: clearing only the visible chat
                # would let the old conversation reappear on the next message,
                # because the handler rebuilds the display from the state.
                clear_btn = gr.ClearButton([msg, chatbot_output, chat_history])

    with gr.Row():
        with gr.Accordion("Options", open=True):
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                label="Choose a Model",
                value="deepseek-ai/DeepSeek-R1",
            )
            with gr.Accordion("Optional Parameters", open=False):
                system_message = gr.Textbox(
                    label="System Message",
                    value="You are a friendly Chatbot created by ruslanmv.com",
                    lines=2,
                )
                # Explicit integer step so the token limit is always a whole number.
                max_new_tokens = gr.Slider(
                    minimum=1, maximum=4000, value=200, step=1, label="Max New Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.10, maximum=4.00, value=0.70, label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.10, maximum=1.00, value=0.90, label="Top-p (nucleus sampling)"
                )

    # Event handling: the button click and pressing Enter in the textbox run
    # the same handler with the same inputs and outputs.
    chat_inputs = [msg, chat_history, model_choice, system_message, max_new_tokens, temperature, top_p]
    chat_outputs = [chatbot_output, chat_history, msg]
    submit_btn.click(chatbot, chat_inputs, chat_outputs)
    msg.submit(chatbot, chat_inputs, chat_outputs)
if __name__ == "__main__":
    # Start the web UI only when this file is executed as a script.
    demo.launch()