import os
import gradio as gr
import gspread  # needed for gspread.exceptions.WorksheetNotFound in save_comment_score
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from datetime import datetime, timedelta, timezone
import torch
from config import hugging_face_token, init_google_sheets_client, models, quantized_models, default_model_name, user_names, google_sheets_name, MAX_INTERACTIONS
import spaces

# Hack for ZeroGPU: disable torch.jit.script by replacing it with a no-op
torch.jit.script = lambda f: f
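# NOTE: config.py is a local module (not shown here) that is expected to provide the
# names imported above: hugging_face_token, init_google_sheets_client (returns an
# authorized Google Sheets client), models and quantized_models (dicts mapping display
# names to model paths), default_model_name, user_names, google_sheets_name, and
# MAX_INTERACTIONS. The exact values are project-specific.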
# Initialize Google Sheets client
client = init_google_sheets_client()
sheet = client.open(google_sheets_name)
stories_sheet = sheet.worksheet("Stories")
system_prompts_sheet = sheet.worksheet("System Prompts")
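# Expected worksheet layout (as implied by the parsing below): "Stories" has a header
# row followed by rows of [Title, Story]; "System Prompts" has a header row followed by
# one prompt per row in the first column.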
# Load stories from Google Sheets
def load_stories():
    stories_data = stories_sheet.get_all_values()
    stories = [{"title": story[0], "story": story[1]} for story in stories_data if story[0] != "Title"]  # Skip header row
    return stories
# Load system prompts from Google Sheets
def load_system_prompts():
    system_prompts_data = system_prompts_sheet.get_all_values()
    system_prompts = [prompt[0] for prompt in system_prompts_data[1:]]  # Skip header row
    return system_prompts
# Load available stories and system prompts
stories = load_stories()
system_prompts = load_system_prompts()

# Initialize the selected model
selected_model = default_model_name
tokenizer, model = None, None

# Initialize the data list
data = []
# Load the model and tokenizer for the given model name
def load_model(model_name):
    global tokenizer, model, selected_model
    try:
        # Release the memory of the previous model if one exists
        if model is not None:
            del model
            torch.cuda.empty_cache()

        # Check whether the model is in models or quantized_models and load accordingly
        if model_name in models:
            model_path = models[model_name]
        elif model_name in quantized_models:
            model_path = quantized_models[model_name]
        else:
            raise ValueError(f"Model {model_name} not found in either models or quantized_models.")

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            padding_side='left',
            token=hugging_face_token,
            trust_remote_code=True
        )

        # Ensure the padding token is set
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            token=hugging_face_token,
            trust_remote_code=True
        )

        # Only move to CUDA if it's not a quantized model
        if model_name not in quantized_models:
            model = model.to("cuda")

        selected_model = model_name
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise e
    return tokenizer, model
# Ensure the initial model is loaded
tokenizer, model = load_model(selected_model)

# Chat history
chat_history = []
# Function to handle interaction with the model
def interact(user_input, history, interaction_count, model_name):
    global tokenizer, model
    try:
        if tokenizer is None or model is None:
            raise ValueError("Tokenizer or model is not initialized.")

        # Determine the device to use (either CUDA if available, or CPU)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Only move the model to the device if it's not a quantized model
        if model_name not in quantized_models:
            model = model.to(device)

        if interaction_count >= MAX_INTERACTIONS:
            user_input += ". Thank you for your questions. Our session is now over. Goodbye!"

        messages = history + [{"role": "user", "content": user_input}]

        # Ensure roles alternate correctly
        for i in range(1, len(messages)):
            if messages[i - 1].get("role") == messages[i].get("role"):
                raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Move input tensor to the correct device
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
        chat_history_ids = model.generate(input_ids, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id, temperature=0.1)
        response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        # Update chat history with the generated response
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": response})
        interaction_count += 1

        formatted_history = [(entry["content"], None) if entry["role"] == "user" else (None, entry["content"]) for entry in history if entry["role"] in ["user", "assistant"]]
        return "", formatted_history, history, interaction_count
    except Exception as e:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"Error during interaction: {e}")
        raise gr.Error(f"An error occurred during interaction: {str(e)}")
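# Note on interact()'s return values (matching the click wiring further below): an
# empty string to clear the input textbox, (user, assistant) pairs for the Chatbot
# component, the raw message history, and the updated interaction count.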
# Function to send selected story and initial message
def send_selected_story(title, model_name, system_prompt):
    global chat_history
    global selected_story
    global data  # Ensure data is reset
    data = []  # Reset data for new story
    interaction_count = 1  # Reset interaction count for new story
    tokenizer, model = load_model(model_name)  # Load the appropriate model
    selected_story = title
    for story in stories:
        if story["title"] == title:
            system_prompt = f"""
{system_prompt}
Here is the story:
---
{story['story']}
---
"""
            combined_message = system_prompt.strip()
            if combined_message:
                chat_history = []  # Reset chat history
                chat_history.append({"role": "system", "content": combined_message})
                question_prompt = "Please ask a simple question about the story to encourage interaction."
                _, formatted_history, chat_history, interaction_count = interact(question_prompt, chat_history, interaction_count, model_name)
                return formatted_history, chat_history, gr.update(value=[]), story["story"]
            else:
                print("Combined message is empty.")
        else:
            print("Story title does not match.")
# Function to save comment and score
def save_comment_score(chat_responses, score, comment, story_name, user_name, system_prompt):
    full_chat_history = ""

    # Create formatted chat history with roles
    for message in chat_responses:
        if message[0]:  # User message
            full_chat_history += f"User: {message[0]}\n"
        if message[1]:  # Assistant message
            full_chat_history += f"Assistant: {message[1]}\n"

    timestamp = datetime.now(timezone.utc) - timedelta(hours=3)  # Adjust to GMT-3
    timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
    model_name = selected_model

    # Append data to local data storage
    data.append([
        timestamp_str,
        user_name,
        model_name,
        system_prompt,
        story_name,
        full_chat_history,
        score,
        comment
    ])

    # Append data to Google Sheets
    try:
        user_sheet = client.open(google_sheets_name).worksheet(user_name)
    except gspread.exceptions.WorksheetNotFound:
        user_sheet = client.open(google_sheets_name).add_worksheet(title=user_name, rows="100", cols="20")
    user_sheet.append_row([timestamp_str, user_name, model_name, system_prompt, story_name, full_chat_history, score, comment])

    df = pd.DataFrame(data, columns=["Timestamp", "User Name", "Model Name", "System Prompt", "Story Name", "Chat History", "Score", "Comment"])
    return df[["Chat History", "Score", "Comment"]], gr.update(value="")  # Show only the required columns and clear the comment input box
# Function to load user guide from a file
def load_user_guide():
    with open('user_guide.txt', 'r') as file:
        return file.read()

# Combine both model dictionaries
all_models = {**models, **quantized_models}
# Create the chat interface using Gradio Blocks
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Chat"):
            gr.Markdown("# Demo Chatbot V3")

            gr.Markdown("## Context")
            with gr.Group():
                model_dropdown = gr.Dropdown(choices=list(all_models.keys()), label="Select Model", value=default_model_name)
                user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
                initial_story = stories[0]["title"] if stories else None
                story_dropdown = gr.Dropdown(choices=[story["title"] for story in stories], label="Select Story", value=initial_story)
                system_prompt_dropdown = gr.Dropdown(choices=system_prompts, label="Select System Prompt", value=system_prompts[0])
                send_story_button = gr.Button("Send Story")

            gr.Markdown("## Chat")
            with gr.Group():
                selected_story_textbox = gr.Textbox(label="Selected Story", lines=10, interactive=False)
                chatbot_output = gr.Chatbot(label="Chat History")
                chatbot_input = gr.Textbox(placeholder="Type your message here...", label="User Input")
                send_message_button = gr.Button("Send")

            gr.Markdown("## Evaluation")
            with gr.Group():
                score_input = gr.Slider(minimum=0, maximum=5, step=1, label="Score")
                comment_input = gr.Textbox(placeholder="Add a comment...", label="Comment")
                save_button = gr.Button("Save Score and Comment")
                data_table = gr.DataFrame(headers=["Chat History", "Score", "Comment"])

        with gr.TabItem("User Guide"):
            gr.Textbox(label="User Guide", value=load_user_guide(), lines=20)

    # Hidden state holders for the raw chat history and the interaction counter
    chat_history_json = gr.JSON(value=[], visible=False)
    interaction_count = gr.Number(value=0, visible=False)

    # Wire the buttons to their handlers
    send_story_button.click(fn=send_selected_story, inputs=[story_dropdown, model_dropdown, system_prompt_dropdown], outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox])
    send_message_button.click(fn=interact, inputs=[chatbot_input, chat_history_json, interaction_count, model_dropdown], outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count])
    save_button.click(fn=save_comment_score, inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown], outputs=[data_table, comment_input])

demo.launch()
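# Assumption: when hosted as a Gradio Space on Hugging Face, this file is the Space's
# entry point (conventionally app.py) and is started automatically; locally it can be
# run directly with the Python interpreter.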