"""Gradio demo chatbot for story-based conversations.

The app reads stories and system prompts from a Google Sheets workbook,
formats the conversation with a Hugging Face tokenizer chat template, sends
the prompt to a model through the Replicate client, and appends evaluator
scores and comments to a per-user worksheet in the same workbook.
"""
import os
from datetime import datetime, timedelta, timezone

import gradio as gr
# gspread is referenced by name in save_comment_score (WorksheetNotFound);
# without this import that except clause raised NameError.
import gspread
import pandas as pd
import replicate
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import (
    MAX_INTERACTIONS,
    default_model_name,
    google_sheets_name,
    hugging_face_token,
    huggingface_tokenizer,
    init_google_sheets_client,
    replicate_model,
    replicate_token,
    user_names,
)

# HACK: for ZeroGPU, turn torch.jit.script into a no-op decorator.
torch.jit.script = lambda f: f

# Initialize Google Sheets client
client = init_google_sheets_client()
sheet = client.open(google_sheets_name)
stories_sheet = sheet.worksheet("Stories")
system_prompts_sheet = sheet.worksheet("System Prompts")


def load_stories():
    """Load the stories from the "Stories" worksheet.

    Returns:
        A list of ``{"title": ..., "story": ...}`` dicts built from the first
        two columns of every row whose first cell is not the literal header
        value ``"Title"``.
    """
    stories_data = stories_sheet.get_all_values()
    return [
        {"title": row[0], "story": row[1]}
        for row in stories_data
        if row[0] != "Title"  # Skip header row
    ]


def load_system_prompts():
    """Load the system prompts from the "System Prompts" worksheet.

    Returns:
        A list with the first cell of every row after the first (header) row.
    """
    system_prompts_data = system_prompts_sheet.get_all_values()
    return [row[0] for row in system_prompts_data[1:]]  # Skip header row


# Load available stories and system prompts
stories = load_stories()
system_prompts = load_system_prompts()

# Initialize the selected model
selected_model = default_model_name
tokenizer, model = None, None

# Rows saved during the current story session (see save_comment_score).
data = []

# Initialize replicate client
replicate_api = replicate.Client(api_token=replicate_token)

# Tokenizers already loaded, keyed by model name, so that a chat turn does not
# call AutoTokenizer.from_pretrained again for a model that is already loaded.
_tokenizer_cache = {}


def load_model(model_name):
    """Load (or reuse) the tokenizer for ``model_name`` and mark it selected.

    Only the tokenizer is loaded here; generation is performed remotely
    through ``replicate_api`` in ``interact``.

    Args:
        model_name: A key of ``huggingface_tokenizer``.

    Returns:
        The tokenizer for ``model_name``. The module-level ``tokenizer`` and
        ``selected_model`` globals are updated as a side effect.

    Raises:
        ValueError: If ``model_name`` is not a key of ``huggingface_tokenizer``.
        Exception: Any error raised while loading the tokenizer is printed and
            re-raised unchanged.
    """
    global tokenizer, selected_model
    try:
        if model_name not in huggingface_tokenizer:
            raise ValueError(f"Model {model_name} not found in models")

        loaded = _tokenizer_cache.get(model_name)
        if loaded is None:
            loaded = AutoTokenizer.from_pretrained(
                huggingface_tokenizer[model_name],
                padding_side='left',
                token=hugging_face_token,
                trust_remote_code=True,
            )
            # Ensure the padding token is set
            if loaded.pad_token is None:
                loaded.pad_token = loaded.eos_token
                loaded.add_special_tokens({'pad_token': loaded.eos_token})
            _tokenizer_cache[model_name] = loaded

        tokenizer = loaded
        selected_model = model_name
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        raise  # bare raise keeps the original traceback
    return tokenizer


# Ensure the initial model is loaded (for now only the tokenizer is loaded).
tokenizer = load_model(selected_model)

# Chat history
chat_history = []


def interact(user_input, history, interaction_count, model_name):
    """Send one user message to the model and record the reply.

    Args:
        user_input: Text of the new user message.
        history: List of ``{"role": ..., "content": ...}`` dicts. It is
            extended in place with the new user and assistant messages.
        interaction_count: Number of interactions so far. Once it reaches
            ``MAX_INTERACTIONS`` a closing sentence is appended to the user
            message before it is sent.
        model_name: Key used to select both the tokenizer and the Replicate
            model.

    Returns:
        A 4-tuple ``("", formatted_history, history, interaction_count + 1)``.
        ``formatted_history`` holds one ``(user, None)`` or
        ``(None, assistant)`` tuple per user/assistant message; other roles
        are left out.

    Raises:
        ValueError: If two consecutive messages share the same role, or if no
            tokenizer is available.
    """
    tokenizer = load_model(model_name)
    if tokenizer is None:
        raise ValueError("Tokenizer or model is not initialized.")

    if interaction_count >= MAX_INTERACTIONS:
        user_input += ". Thank you for your questions. Our session is now over. Goodbye!"

    messages = history + [{"role": "user", "content": user_input}]

    # Ensure roles alternate correctly
    for previous, current in zip(messages, messages[1:]):
        if previous.get("role") == current.get("role"):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Generate request. The prompt already carries the chat template, so
    # "prompt_template" passes it through as-is.
    inpt = {
        "prompt": prompt,
        "max_new_tokens": 100,
        "temperature": 0.1,
        "prompt_template": "{prompt}",
    }

    # Make request; the chunks returned by the client are joined into one string.
    response = replicate_api.run(replicate_model[model_name], input=inpt)
    response = "".join(response).strip()

    # Update chat history with generated response
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})

    interaction_count += 1

    formatted_history = [
        (entry["content"], None) if entry["role"] == "user" else (None, entry["content"])
        for entry in history
        if entry["role"] in ["user", "assistant"]
    ]
    return "", formatted_history, history, interaction_count


def send_selected_story(title, model_name, system_prompt):
    """Start a new conversation about the story called ``title``.

    Resets the module-level ``data`` and ``chat_history``, builds a system
    message from ``system_prompt`` plus the story text, and asks the model for
    an opening question through ``interact``.

    Args:
        title: Title of the story to look up in ``stories``.
        model_name: Key of the model to use.
        system_prompt: Text placed before the story in the system message.

    Returns:
        A 4-tuple ``(formatted_history, chat_history, table_update,
        story_text)`` matching the four outputs wired to the "Send Story"
        button. When no story has the given title, empty values are returned
        for all four outputs instead of ``None``.
    """
    global chat_history, selected_story, data

    data = []  # Reset data for new story
    # NOTE(review): the count returned by interact() below is discarded and
    # the hidden `interaction_count` component is not among this handler's
    # outputs, so the UI counter is not reset when a new story starts.
    interaction_count = 1
    load_model(model_name)  # Called for its side effect on `selected_model`
    selected_story = title

    story = next((candidate for candidate in stories if candidate["title"] == title), None)
    if story is None:
        print("Story title does not match.")
        chat_history = []
        return [], chat_history, gr.update(value=[]), ""

    # NOTE(review): the literal is kept exactly as it appears in the source;
    # if the original layout had line breaks around the "---" separators they
    # are not recoverable from here. It always contains fixed text, so the
    # stripped message can never be empty.
    combined_message = f""" {system_prompt} Here is the story: --- {story['story']} --- """.strip()

    chat_history = [{"role": "system", "content": combined_message}]  # Reset chat history

    question_prompt = "Please ask a simple question about the story to encourage interaction."
    _, formatted_history, chat_history, interaction_count = interact(
        question_prompt, chat_history, interaction_count, model_name
    )
    return formatted_history, chat_history, gr.update(value=[]), story["story"]


def save_comment_score(chat_responses, score, comment, story_name, user_name, system_prompt):
    """Store a score and comment for the current chat, locally and in Sheets.

    Args:
        chat_responses: Sequence of ``(user_text, assistant_text)`` pairs as
            held by the chatbot component; either element may be empty.
        score: Score chosen by the evaluator.
        comment: Free-text comment.
        story_name: Title of the evaluated story.
        user_name: Evaluator name; also the title of the worksheet the row is
            appended to (created if it does not exist).
        system_prompt: System prompt used for the conversation.

    Returns:
        A 2-tuple: a DataFrame with the "Chat History", "Score" and "Comment"
        columns of every row saved since the last story reset, and an update
        that clears the comment input box.

    Raises:
        ValueError: If ``user_name`` is empty, since no worksheet can be
            selected for it.
    """
    if not user_name:
        raise ValueError("A user name must be selected before saving.")

    # Create formatted chat history with roles
    transcript_lines = []
    for message in chat_responses or []:
        if message[0]:  # User message
            transcript_lines.append(f"User: {message[0]}\n")
        if message[1]:  # Assistant message
            transcript_lines.append(f"Assistant: {message[1]}\n")
    full_chat_history = "".join(transcript_lines)

    timestamp = datetime.now(timezone.utc) - timedelta(hours=3)  # Adjust to GMT-3
    timestamp_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
    model_name = selected_model

    row = [
        timestamp_str,
        user_name,
        model_name,
        system_prompt,
        story_name,
        full_chat_history,
        score,
        comment,
    ]

    # Append data to local data storage (a copy, so the sheet row stays independent)
    data.append(list(row))

    # Append data to Google Sheets, creating the user's worksheet on first use
    spreadsheet = client.open(google_sheets_name)
    try:
        user_sheet = spreadsheet.worksheet(user_name)
    except gspread.exceptions.WorksheetNotFound:
        user_sheet = spreadsheet.add_worksheet(title=user_name, rows="100", cols="20")
    user_sheet.append_row(row)

    df = pd.DataFrame(
        data,
        columns=[
            "Timestamp",
            "User Name",
            "Model Name",
            "System Prompt",
            "Story Name",
            "Chat History",
            "Score",
            "Comment",
        ],
    )
    # Show only the required columns and clear the comment input box
    return df[["Chat History", "Score", "Comment"]], gr.update(value="")


def load_user_guide():
    """Return the contents of ``user_guide.txt`` read as UTF-8 text."""
    with open('user_guide.txt', 'r', encoding='utf-8') as file:
        return file.read()


# Models offered in the UI (currently just the tokenizer mapping).
all_models = {**huggingface_tokenizer}

# Create the chat interface using Gradio Blocks.
# NOTE(review): widget nesting below is reconstructed from the order of the
# calls; confirm the grouping against the intended layout.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Chat"):
            gr.Markdown("# Demo Chatbot V3")

            gr.Markdown("## Context")
            with gr.Group():
                model_dropdown = gr.Dropdown(
                    choices=list(all_models.keys()),
                    label="Select Model",
                    value=default_model_name,
                )
                user_dropdown = gr.Dropdown(choices=user_names, label="Select User Name")
                initial_story = stories[0]["title"] if stories else None
                story_dropdown = gr.Dropdown(
                    choices=[story["title"] for story in stories],
                    label="Select Story",
                    value=initial_story,
                )
                # Guard the empty case the same way as initial_story above.
                initial_system_prompt = system_prompts[0] if system_prompts else None
                system_prompt_dropdown = gr.Dropdown(
                    choices=system_prompts,
                    label="Select System Prompt",
                    value=initial_system_prompt,
                )
                send_story_button = gr.Button("Send Story")

            gr.Markdown("## Chat")
            with gr.Group():
                selected_story_textbox = gr.Textbox(label="Selected Story", lines=10, interactive=False)
                chatbot_output = gr.Chatbot(label="Chat History")
                chatbot_input = gr.Textbox(placeholder="Type your message here...", label="User Input")
                send_message_button = gr.Button("Send")

            gr.Markdown("## Evaluation")
            with gr.Group():
                score_input = gr.Slider(minimum=0, maximum=5, step=1, label="Score")
                comment_input = gr.Textbox(placeholder="Add a comment...", label="Comment")
                save_button = gr.Button("Save Score and Comment")
                data_table = gr.DataFrame(headers=["Chat History", "Score", "Comment"])

        with gr.TabItem("User Guide"):
            gr.Textbox(label="User Guide", value=load_user_guide(), lines=20)

    # Hidden per-session state passed between the event handlers.
    chat_history_json = gr.JSON(value=[], visible=False)
    interaction_count = gr.Number(value=0, visible=False)

    send_story_button.click(
        fn=send_selected_story,
        inputs=[story_dropdown, model_dropdown, system_prompt_dropdown],
        outputs=[chatbot_output, chat_history_json, data_table, selected_story_textbox],
    )
    send_message_button.click(
        fn=interact,
        inputs=[chatbot_input, chat_history_json, interaction_count, model_dropdown],
        outputs=[chatbot_input, chatbot_output, chat_history_json, interaction_count],
    )
    save_button.click(
        fn=save_comment_score,
        inputs=[chatbot_output, score_input, comment_input, story_dropdown, user_dropdown, system_prompt_dropdown],
        outputs=[data_table, comment_input],
    )

demo.launch()