# ruff: noqa: E501
from __future__ import annotations

import asyncio
import datetime
import pytz
import logging
import os
from enum import Enum
import json
import uuid
from pydantic import BaseModel, Field
import gspread
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import gradio as gr
import tiktoken

# from dotenv import load_dotenv
# load_dotenv()
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
from langchain.chains import ConversationChain
from langsmith import Client
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage
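
# Gradio chat app for a CBS MBA course: a GPT-4 "Debate Partner" for case
# discussions and a Claude 2 "Research Assistant". Completed turns are flagged
# to a Hugging Face dataset and logged to a Google Sheet.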
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)

GPT_3_5_CONTEXT_LENGTH = 4096
CLAUDE_2_CONTEXT_LENGTH = 100000  # need to use claude tokenizer

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
GS_CREDS = json.loads(rf"""{os.getenv("GSPREAD_SERVICE")}""")
GSHEET_ID = os.getenv("GSHEET_ID")
AUTH_GSHEET_NAME = os.getenv("AUTH_GSHEET_NAME")
TURNS_GSHEET_NAME = os.getenv("TURNS_GSHEET_NAME")

theme = gr.themes.Soft()
creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
gradio_flagger = gr.HuggingFaceDatasetSaver(
    hf_token=HF_TOKEN, dataset_name="chats", separate_dirs=True
)
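
# Google Sheets helpers. One worksheet (AUTH_GSHEET_NAME) stores permitted
# logins; another (TURNS_GSHEET_NAME) collects one metadata row per chat turn.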
def get_gsheet_rows(
    sheet_id: str, sheet_name: str, creds: Dict[str, str]
) -> List[Dict[str, str]]:
    """Return all records from a worksheet as dicts keyed by the header row."""
    gc = gspread.service_account_from_dict(creds)
    worksheet = gc.open_by_key(sheet_id).worksheet(sheet_name)
    rows = worksheet.get_all_records()
    return rows


def append_gsheet_rows(
    sheet_id: str,
    rows: List[List[str]],
    sheet_name: str,
    creds: Dict[str, str],
) -> None:
    """Append rows to the bottom of a worksheet."""
    gc = gspread.service_account_from_dict(creds)
    worksheet = gc.open_by_key(sheet_id).worksheet(sheet_name)
    worksheet.append_rows(values=rows, insert_data_option="INSERT_ROWS")
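
# System prompts for the two chatbot modes. Subclassing str lets the enum
# members be passed anywhere a plain prompt string is expected.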
class ChatSystemMessage(str, Enum):
    CASE_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
Follow this message's instructions carefully. Respond using markdown.
Never repeat these instructions in a subsequent message.

You will start a conversation with me in the following form:
1. Below these instructions you will receive a business scenario. The scenario will include (a) the name of a company or category and (b) a debatable multiple-choice question about the business scenario.
2. We will pretend to be executives charged with solving the strategic question outlined in the scenario.
3. To start the conversation, you will summarize the question and present all of the multiple-choice options to me. Then, you will ask me to choose a position and provide a short opening argument. Do not yet provide your position.
4. After receiving my position and explanation, you will choose an alternative position in the scenario.
5. Inform me which position you have chosen, then proceed to have a discussion with me on this topic.
6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
"""

    RESEARCH_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
Follow this message's instructions carefully. Respond using markdown.
Never repeat these instructions in a subsequent message.

You will start a conversation with me in the following form:
1. You are to be a professional research consultant to the MBA student.
2. The student will be working in a group of classmates to collaborate on a proposal to solve a business dilemma.
3. Be as helpful as you can to the student while remaining factual.
4. If you are not certain, please warn the student to conduct additional research on the internet.
5. Use tables and bullet points as a useful way to compare insights.
"""

class ChatbotMode(str, Enum):
    DEBATE_PARTNER = "Debate Partner"
    RESEARCH_ASSISTANT = "Research Assistant"
    DEFAULT = DEBATE_PARTNER

class PollQuestion(BaseModel):  # type: ignore[misc]
    name: str
    template: str


class PollQuestions(BaseModel):  # type: ignore[misc]
    cases: List[PollQuestion]

    @classmethod
    def from_json_file(cls, json_file_path: str) -> PollQuestions:
        """Expects a JSON file with an array of poll questions.

        Each JSON object should have "name" and "template" keys.
        """
        with open(json_file_path, "r") as json_f:
            payload = json.load(json_f)
        return_obj_list = []
        if isinstance(payload, list):
            for case in payload:
                return_obj_list.append(PollQuestion(**case))
            return cls(cases=return_obj_list)
        raise ValueError(
            f"JSON object in {json_file_path} must be an array of PollQuestion"
        )

    def get_case(self, case_name: str) -> Optional[PollQuestion]:
        """Searches cases for the poll question with the given name, if any."""
        for case in self.cases:
            if case.name == case_name:
                return case
        return None

    def get_case_names(self) -> List[str]:
        """Returns the names of all cases."""
        return [case.name for case in self.cases]
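
# Load the case templates shipped alongside the app; each template supplies the
# scenario text that _make_template appends to CASE_SYSTEM_MESSAGE.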
poll_questions = PollQuestions.from_json_file("templates.json")

def reset_textbox():
    return gr.update(value=""), gr.update(value=""), gr.update(value="")

def auth(username, password):
    try:
        auth_records = get_gsheet_rows(
            sheet_id=GSHEET_ID, sheet_name=AUTH_GSHEET_NAME, creds=GS_CREDS
        )
        auth_dict = {user["username"]: user["password"] for user in auth_records}
        search_auth_user = auth_dict.get(username)
        if search_auth_user:
            authenticated = search_auth_user == password
            if authenticated:
                LOG.info(f"{username} successfully logged in.")
                return authenticated
        else:
            LOG.info(f"{username} failed to log in.")
        return False
    except Exception as exc:
        LOG.info(f"{username} failed to log in.")
        LOG.error(exc)
        # Fall back to the single CHAT_USERNAME / CHAT_PASSWORD pair when the
        # Google Sheet lookup is unavailable.
        return (username, password) in creds
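
# ChatSession bundles per-user session state: the LangChain ConversationChain,
# its tokenizer and context budget, and the (user, bot) history shown in the UI.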
class ChatSession(BaseModel):
    class Config:
        arbitrary_types_allowed = True

    context_length: int
    tokenizer: tiktoken.Encoding
    chain: ConversationChain
    history: List[Tuple[str, str]] = []
    # default_factory gives every session its own id; a plain str(uuid.uuid4())
    # default would be evaluated once and shared by all sessions in the process.
    session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    @staticmethod
    def set_metadata(
        username: str,
        chatbot_mode: str,
        turns_completed: int,
        case: Optional[str] = None,
    ) -> Dict[str, Union[str, int, None]]:
        metadata = dict(
            username=username,
            chatbot_mode=chatbot_mode,
            turns_completed=turns_completed,
            case=case,
        )
        return metadata
    @staticmethod
    def _make_template(
        system_msg: str, poll_question_name: Optional[str] = None
    ) -> ChatPromptTemplate:
        knowledge_cutoff = "Sept 2021"
        current_date = datetime.datetime.now(
            pytz.timezone("America/New_York")
        ).strftime("%Y-%m-%d")
        if poll_question_name:
            poll_question = poll_questions.get_case(poll_question_name)
            if poll_question:
                message_template = poll_question.template
                system_msg += f"""
{message_template}
Knowledge cutoff: {knowledge_cutoff}
Current date: {current_date}
"""
        else:
            knowledge_cutoff = "Early 2023"
            system_msg += f"""
Knowledge cutoff: {knowledge_cutoff}
Current date: {current_date}
"""
        human_template = "{input}"
        return ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(system_msg),
                MessagesPlaceholder(variable_name="history"),
                HumanMessagePromptTemplate.from_template(human_template),
            ]
        )
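
    # Note: the "Sept 2021" cutoff corresponds to the GPT-4 case mode; the
    # no-case branch overrides it with "Early 2023" for the Claude 2 research mode.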
    @staticmethod
    def _set_llm(
        use_claude: bool,
    ) -> Tuple[Union[ChatOpenAI, ChatAnthropic], int, tiktoken.Encoding]:
        if use_claude:
            llm = ChatAnthropic(
                model="claude-2",
                anthropic_api_key=ANTHROPIC_API_KEY,
                temperature=1,
                max_tokens_to_sample=5000,
                streaming=True,
            )
            context_length = CLAUDE_2_CONTEXT_LENGTH
            # cl100k_base is an OpenAI encoding used here only as an
            # approximation; Anthropic's own tokenizer would be more accurate.
            tokenizer = tiktoken.get_encoding("cl100k_base")
            return llm, context_length, tokenizer
        else:
            llm = ChatOpenAI(
                model_name="gpt-4",
                temperature=1,
                openai_api_key=OPENAI_API_KEY,
                max_retries=6,
                request_timeout=100,
                streaming=True,
            )
            # Reuses the 4,096-token GPT-3.5 limit as a conservative budget
            # for gpt-4.
            context_length = GPT_3_5_CONTEXT_LENGTH
            _, tokenizer = llm._get_encoding_model()
            return llm, context_length, tokenizer
    def update_system_prompt(
        self, system_msg: str, poll_question_name: Optional[str] = None
    ) -> None:
        self.chain.prompt = self._make_template(system_msg, poll_question_name)

    def change_llm(self, use_claude: bool) -> None:
        llm, self.context_length, self.tokenizer = self._set_llm(use_claude)
        self.chain.llm = llm

    def clear_memory(self) -> None:
        self.chain.memory.clear()
        self.history = []

    def set_chatbot_mode(
        self, case_mode: bool, poll_question_name: Optional[str] = None
    ) -> None:
        if case_mode and poll_question_name:
            self.change_llm(use_claude=False)
            self.update_system_prompt(
                system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
                poll_question_name=poll_question_name,
            )
        else:
            self.change_llm(use_claude=True)
            self.update_system_prompt(
                system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE
            )
    @classmethod
    def new(
        cls,
        use_claude: bool,
        system_msg: str,
        metadata: Dict[str, Any],
        poll_question_name: Optional[str] = None,
    ) -> ChatSession:
        llm, context_length, tokenizer = cls._set_llm(use_claude)
        memory = ConversationTokenBufferMemory(
            llm=llm, max_token_limit=context_length, return_messages=True
        )
        template = cls._make_template(
            system_msg=system_msg, poll_question_name=poll_question_name
        )
        chain = ConversationChain(
            memory=memory,
            prompt=template,
            llm=llm,
            metadata=metadata,
        )
        return cls(
            context_length=context_length,
            tokenizer=tokenizer,
            chain=chain,
        )
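
# respond() is the main chat handler: it lazily creates a session, streams
# tokens back to the UI, then logs the finished turn to the HF dataset and
# the Google Sheet.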
async def respond(
    chat_input: str,
    chatbot_mode: str,
    case_input: str,
    state: ChatSession,
    request: gr.Request,
) -> AsyncGenerator[Tuple[List[Tuple[str, str]], ChatSession, Optional[str]], None]:
    """Execute the chat functionality, streaming partial responses."""

    def prep_messages(
        user_msg: str, memory_buffer: List[BaseMessage]
    ) -> Tuple[str, List[BaseMessage]]:
        messages_to_send = state.chain.prompt.format_messages(
            input=user_msg, history=memory_buffer
        )
        user_msg_token_count = state.chain.llm.get_num_tokens_from_messages(
            [messages_to_send[-1]]
        )
        total_token_count = state.chain.llm.get_num_tokens_from_messages(
            messages_to_send
        )
        while user_msg_token_count > state.context_length:
            LOG.warning(
                f"Pruning user message due to user message token length of {user_msg_token_count}"
            )
            user_msg = state.tokenizer.decode(
                state.chain.llm.get_token_ids(user_msg)[: state.context_length - 100]
            )
            messages_to_send = state.chain.prompt.format_messages(
                input=user_msg, history=memory_buffer
            )
            user_msg_token_count = state.chain.llm.get_num_tokens_from_messages(
                [messages_to_send[-1]]
            )
            total_token_count = state.chain.llm.get_num_tokens_from_messages(
                messages_to_send
            )
        while total_token_count > state.context_length:
            LOG.warning(
                f"Pruning memory due to total token length of {total_token_count}"
            )
            if not memory_buffer:
                break  # nothing left to prune; the system prompt alone overflows
            memory_buffer = memory_buffer[1:]
            messages_to_send = state.chain.prompt.format_messages(
                input=user_msg, history=memory_buffer
            )
            total_token_count = state.chain.llm.get_num_tokens_from_messages(
                messages_to_send
            )
        return user_msg, memory_buffer
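
    # Pruning order: first truncate an oversized user message, then drop the
    # oldest memory entries until the formatted prompt fits the context window.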
    try:
        if state is None:
            if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
                new_session = ChatSession.new(
                    use_claude=False,
                    system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
                    metadata=ChatSession.set_metadata(
                        username=request.username,
                        chatbot_mode=chatbot_mode,
                        turns_completed=0,
                        case=case_input,
                    ),
                    poll_question_name=case_input,
                )
            else:
                new_session = ChatSession.new(
                    use_claude=True,
                    system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE,
                    metadata=ChatSession.set_metadata(
                        username=request.username,
                        chatbot_mode=chatbot_mode,
                        turns_completed=0,
                    ),
                    poll_question_name=None,
                )
            state = new_session
        state.chain.metadata = ChatSession.set_metadata(
            username=request.username,
            chatbot_mode=chatbot_mode,
            turns_completed=len(state.history) + 1,
            case=case_input,
        )
        LOG.info(f"""[{request.username}] STARTING CHAIN""")
        LOG.debug(f"History: {state.history}")
        LOG.debug(f"User input: {chat_input}")
        chat_input, state.chain.memory.chat_memory.messages = prep_messages(
            chat_input, state.chain.memory.buffer
        )
        messages_to_send = state.chain.prompt.format_messages(
            input=chat_input, history=state.chain.memory.buffer
        )
        total_token_count = state.chain.llm.get_num_tokens_from_messages(
            messages_to_send
        )
        LOG.debug(f"Messages to send: {messages_to_send}")
        LOG.debug(f"Tokens to send: {total_token_count}")
        callback = AsyncIteratorCallbackHandler()
        run_collector = RunCollectorCallbackHandler()
        run = asyncio.create_task(
            state.chain.apredict(
                input=chat_input,
                callbacks=[callback, run_collector],
            )
        )
        state.history.append((chat_input, ""))
        run_id = None
        langsmith_url = None
        url_markdown = None  # stays None if no traced run id is available
        async for tok in callback.aiter():
            user, bot = state.history[-1]
            bot += tok
            state.history[-1] = (user, bot)
            yield state.history, state, None
        await run
        if run_collector.traced_runs and run_id is None:
            run_id = run_collector.traced_runs[0].id
            LOG.info(f"RUNID: {run_id}")
        if run_id:
            run_collector.traced_runs = []
            try:
                langsmith_url = Client().share_run(run_id)
                LOG.info(f"""Run ID: {run_id} \n URL : {langsmith_url}""")
                url_markdown = (
                    f"""[Click to view shareable chat]({langsmith_url})"""
                )
            except Exception as exc:
                LOG.error(exc)
                url_markdown = "Share link not currently available"
            if (
                len(state.history) > 9
                and chatbot_mode == ChatbotMode.DEBATE_PARTNER
            ):
                url_markdown += """\n
🎉 You have completed 10 exchanges with the chatbot."""
        yield state.history, state, url_markdown
        LOG.info(f"""[{request.username}] ENDING CHAIN""")
        LOG.debug(f"History: {state.history}")
        LOG.debug(f"Memory: {state.chain.memory.json()}")
        current_timestamp = datetime.datetime.now(pytz.timezone("US/Eastern")).replace(
            tzinfo=None
        )
        timestamp_string = current_timestamp.strftime("%Y-%m-%d %H:%M:%S")
        data_to_flag = (
            {
                "history": deepcopy(state.history),
                "username": request.username,
                "timestamp": timestamp_string,
                "session_id": state.session_id,
                "metadata": state.chain.metadata,
                "langsmith_url": langsmith_url,
            },
        )
        gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
        (flagged_data,) = data_to_flag
        metadata_to_gsheet = flagged_data.get("metadata").values()
        gsheet_row = [[timestamp_string, *metadata_to_gsheet, langsmith_url]]
        LOG.info(f"Data to GSHEET: {gsheet_row}")
        append_gsheet_rows(
            sheet_id=GSHEET_ID,
            sheet_name=TURNS_GSHEET_NAME,
            rows=gsheet_row,
            creds=GS_CREDS,
        )
    except Exception as e:
        LOG.error(e)
        raise e

class ChatbotConfig(BaseModel):
    app_title: str = "CBS Technology Strategy - Fall 2023"
    chatbot_modes: List[ChatbotMode] = [mode.value for mode in ChatbotMode]
    case_options: List[str] = poll_questions.get_case_names()
    default_case_option: str = "Netflix"

def change_chatbot_mode(
    state: ChatSession, chatbot_mode: str, poll_question_name: str, request: gr.Request
) -> Tuple[Any, ChatSession]:
    """Sets the visibility of the case input field and returns the updated state."""
    if state is None:
        if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
            new_session = ChatSession.new(
                use_claude=False,
                system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
                metadata=ChatSession.set_metadata(
                    username=request.username,
                    chatbot_mode=chatbot_mode,
                    turns_completed=0,
                    case=poll_question_name,
                ),
                poll_question_name=poll_question_name,
            )
        else:
            new_session = ChatSession.new(
                use_claude=True,
                system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE,
                metadata=ChatSession.set_metadata(
                    username=request.username,
                    chatbot_mode=chatbot_mode,
                    turns_completed=0,
                ),
                poll_question_name=None,
            )
        state = new_session
    if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
        state.set_chatbot_mode(case_mode=True, poll_question_name=poll_question_name)
        state.clear_memory()
        return gr.update(visible=True), state
    elif chatbot_mode == ChatbotMode.RESEARCH_ASSISTANT:
        state.set_chatbot_mode(case_mode=False)
        state.clear_memory()
        return gr.update(visible=False), state
    else:
        raise ValueError("chatbot_mode is not correctly set")

config = ChatbotConfig()
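
# Build the Gradio UI: a mode selector, a case dropdown (visible only in debate
# mode), the chat window, and a status line for the shareable LangSmith link.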
with gr.Blocks(
    theme=theme,
    analytics_enabled=False,
    title=config.app_title,
) as demo:
    state = gr.State()
    gr.Markdown(f"""### {config.app_title}""")
    with gr.Tab("Chatbot"):
        with gr.Row():
            chatbot_mode = gr.Radio(
                label="Mode",
                choices=config.chatbot_modes,
                value=ChatbotMode.DEFAULT,
            )
            case_input = gr.Dropdown(
                label="Case",
                choices=config.case_options,
                value=config.default_case_option,
                multiselect=False,
            )
        chatbot = gr.Chatbot(label="ChatBot", show_share_button=False)
        with gr.Row():
            input_message = gr.Textbox(
                placeholder="Send a message.",
                label="Type a message to begin",
                scale=5,
            )
            chat_submit_button = gr.Button(value="Submit")
        status_message = gr.Markdown()
    gradio_flagger.setup([chatbot], "chats")
    chatbot_submit_params = dict(
        fn=respond,
        inputs=[input_message, chatbot_mode, case_input, state],
        outputs=[chatbot, state, status_message],
    )
    input_message.submit(**chatbot_submit_params)
    chat_submit_button.click(**chatbot_submit_params)
    chatbot_mode_params = dict(
        fn=change_chatbot_mode,
        inputs=[state, chatbot_mode, case_input],
        outputs=[case_input, state],
    )
    chatbot_mode.change(**chatbot_mode_params)
    case_input.change(**chatbot_mode_params)
    clear_chatbot_messages_params = dict(
        fn=reset_textbox, inputs=[], outputs=[input_message, chatbot, status_message]
    )
    chatbot_mode.change(**clear_chatbot_messages_params)
    case_input.change(**clear_chatbot_messages_params)
    chat_submit_button.click(**clear_chatbot_messages_params)
    input_message.submit(**clear_chatbot_messages_params)

demo.queue(max_size=99, concurrency_count=99, api_open=False).launch(
    debug=True, auth=auth
)
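
# Note: queue(concurrency_count=...) assumes Gradio 3.x; Gradio 4 replaced it
# with per-event concurrency limits, so pin gradio<4 when reproducing this Space.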