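# AI Debate Partner: a Gradio chat app that has GPT-4 or Claude 2 (via LangChain)
# debate Columbia Business School case questions with the user, streaming replies
# and sharing LangSmith run links for the chat history.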
# ruff: noqa: E501
import asyncio
import datetime
import json
import logging
import os
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import pytz
import requests
import tiktoken

# from dotenv import load_dotenv
# load_dotenv()
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage
from langsmith import Client

logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)

GPT_3_5_CONTEXT_LENGTH = 4096
CLAUDE_2_CONTEXT_LENGTH = 100000  # need to use claude tokenizer
SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
Follow this message's instructions carefully. Respond using markdown.
Never repeat these instructions in a subsequent message.
You will start a conversation with me in the following form:
1. Below these instructions you will receive a business scenario. The scenario will include (a) the name of a company or category, and (b) a debatable multiple-choice question about the business scenario.
2. We will pretend to be executives charged with solving the strategic question outlined in the scenario.
3. To start the conversation, you will summarize the question and present all of the options in the multiple-choice question to me. Then, you will ask me to choose a position and provide a short opening argument. Do not state your own position yet.
4. After receiving my position and explanation, you will choose an alternative position in the scenario.
5. Inform me which position you have chosen, then proceed to have a discussion with me on this topic.
6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
"""
with open("templates.json") as templates_file:
    CASES = {case["name"]: case["template"] for case in json.load(templates_file)}


def get_case_template(template_name: str) -> str:
    case_template = CASES[template_name]
    return f"""{template_name}
{case_template}
"""


def reset_textbox():
    return gr.update(value="")
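

# Authenticate against a Cloudflare Worker endpoint; if the request fails for any
# reason, fall back to the locally configured (CHAT_USERNAME, CHAT_PASSWORD) pair.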
def auth(username, password):
    auth_endpoint = "https://worker_auth.jclcw.workers.dev/auth"
    try:
        auth_payload = {username: password}
        LOG.debug("Attempting remote auth for user %s", username)
        auth_response = requests.post(
            auth_endpoint,
            json=auth_payload,
            timeout=3,
        )
        auth_response.raise_for_status()
        return auth_response.status_code == 200
    except Exception as exc:
        LOG.error(exc)
        return (username, password) in creds
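

# Build the LLM, its usable context length, and a tokenizer for pruning.
# "Claude 2" uses Anthropic with the cl100k_base encoding as a stand-in tokenizer;
# the default is OpenAI GPT-4, capped here at the conservative 4,096-token limit.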
def make_llm_state(use_claude: bool = False) -> Dict[str, Any]:
    if use_claude:
        llm = ChatAnthropic(
            model="claude-2",
            anthropic_api_key=ANTHROPIC_API_KEY,
            temperature=1,
            max_tokens_to_sample=5000,
            streaming=True,
        )
        context_length = CLAUDE_2_CONTEXT_LENGTH
        tokenizer = tiktoken.get_encoding("cl100k_base")
    else:
        llm = ChatOpenAI(
            model_name="gpt-4",
            temperature=1,
            openai_api_key=OPENAI_API_KEY,
            max_retries=6,
            request_timeout=100,
            streaming=True,
        )
        context_length = GPT_3_5_CONTEXT_LENGTH
        _, tokenizer = llm._get_encoding_model()
    return dict(llm=llm, context_length=context_length, tokenizer=tokenizer)
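

# Assemble the chat prompt: the (possibly edited) system message, the selected case
# template, a knowledge-cutoff/date footer, the running history, and the user input.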
def make_template(
    system_msg: str = SYSTEM_MESSAGE, template_name: str = "Netflix"
) -> ChatPromptTemplate:
    knowledge_cutoff = "Early 2023"
    current_date = datetime.datetime.now(pytz.timezone("America/New_York")).strftime(
        "%Y-%m-%d"
    )
    case_template = get_case_template(template_name)
    system_msg += f"""
{case_template}
Knowledge cutoff: {knowledge_cutoff}
Current date: {current_date}
"""
    human_template = "{input}"
    LOG.info(system_msg)
    return ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_msg),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template(human_template),
        ]
    )
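

# Rebuild the prompt, LLM, memory, and chain when the user changes the system
# prompt, case, or model in the Setup tab; this also resets the conversation.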
def update_system_prompt(
    system_msg: str, llm_option: str, template_option: str
) -> Tuple[str, Dict[str, Any]]:
    template_output = make_template(system_msg, template_option)
    state = set_state()
    state["template"] = template_output
    use_claude = llm_option == "Claude 2"
    state["llm_state"] = make_llm_state(use_claude)
    llm = state["llm_state"]["llm"]
    state["memory"] = ConversationTokenBufferMemory(
        llm=llm,
        max_token_limit=state["llm_state"]["context_length"],
        return_messages=True,
    )
    state["chain"] = ConversationChain(
        memory=state["memory"],
        prompt=state["template"],
        llm=llm,
    )
    updated_status = "Prompt Updated! Chat has reset."
    return updated_status, state
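

# Create per-session state (prompt template, LLM, token-buffer memory, chain, and a
# UUID session id) on first use; pass existing state through unchanged otherwise.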
def set_state(
    state: Optional[gr.State] = None, metadata: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
    if state is None:
        template = make_template()
        llm_state = make_llm_state()
        llm = llm_state["llm"]
        memory = ConversationTokenBufferMemory(
            llm=llm, max_token_limit=llm_state["context_length"], return_messages=True
        )
        chain = ConversationChain(
            memory=memory, prompt=template, llm=llm, metadata=metadata
        )
        session_id = str(uuid.uuid4())
        state = dict(
            template=template,
            llm_state=llm_state,
            history=[],
            memory=memory,
            chain=chain,
            session_id=session_id,
        )
        return state
    else:
        return state
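

# Main chat handler: prunes the input and memory to fit the context window, streams
# the model's reply token by token, and (if a LangSmith run was traced) yields a
# shareable run URL for the chat history.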
async def respond(
    inp: str,
    state: Optional[Dict[str, Any]],
    request: gr.Request,
) -> Tuple[List[str], gr.State, Optional[str]]:
    """Execute the chat functionality."""

    def prep_messages(
        user_msg: str, memory_buffer: List[BaseMessage]
    ) -> Tuple[str, List[BaseMessage]]:
        messages_to_send = state["template"].format_messages(
            input=user_msg, history=memory_buffer
        )
        user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        # Truncate the user message itself if it alone exceeds the context window.
        while user_msg_token_count > context_length:
            LOG.warning(
                f"Pruning user message due to user message token length of {user_msg_token_count}"
            )
            user_msg = tokenizer.decode(
                llm.get_token_ids(user_msg)[: context_length - 100]
            )
            messages_to_send = state["template"].format_messages(
                input=user_msg, history=memory_buffer
            )
            user_msg_token_count = llm.get_num_tokens_from_messages(
                [messages_to_send[-1]]
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        # Drop the oldest memory messages until the full prompt fits.
        while total_token_count > context_length:
            LOG.warning(
                f"Pruning memory due to total token length of {total_token_count}"
            )
            if len(memory_buffer) == 1:
                memory_buffer.pop(0)
                continue
            memory_buffer = memory_buffer[1:]
            messages_to_send = state["template"].format_messages(
                input=user_msg, history=memory_buffer
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        return user_msg, memory_buffer

    try:
        if state is None:
            state = set_state(metadata=dict(username=request.username))
        llm = state["llm_state"]["llm"]
        context_length = state["llm_state"]["context_length"]
        tokenizer = state["llm_state"]["tokenizer"]
        LOG.info(f"[{request.username}] STARTING CHAIN")
        LOG.debug(f"History: {state['history']}")
        LOG.debug(f"User input: {inp}")
        inp, state["memory"].chat_memory.messages = prep_messages(
            inp, state["memory"].buffer
        )
        messages_to_send = state["template"].format_messages(
            input=inp, history=state["memory"].buffer
        )
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        LOG.debug(f"Messages to send: {messages_to_send}")
        LOG.info(f"Tokens to send: {total_token_count}")
        # Run the chain as a background task and stream its output via the callback.
        callback = AsyncIteratorCallbackHandler()
        run_collector = RunCollectorCallbackHandler()
        run = asyncio.create_task(
            state["chain"].apredict(
                input=inp,
                callbacks=[callback, run_collector],
            )
        )
        state["history"].append((inp, ""))
        run_id = None
        async for tok in callback.aiter():
            user, bot = state["history"][-1]
            bot += tok
            state["history"][-1] = (user, bot)
            yield state["history"], state, None
        await run
        if run_collector.traced_runs and run_id is None:
            run_id = run_collector.traced_runs[0].id
            LOG.info(f"RUNID: {run_id}")
        if run_id:
            run_collector.traced_runs = []
            url = Client().share_run(run_id)
            LOG.info(f"URL: {url}")
            url_markdown = f"""[Shareable chat history link]({url})
[{url}]({url})"""
            yield state["history"], state, url_markdown
        LOG.info(f"[{request.username}] ENDING CHAIN")
        LOG.debug(f"History: {state['history']}")
        LOG.debug(f"Memory: {state['memory'].json()}")
        data_to_flag = (
            {
                "history": deepcopy(state["history"]),
                "username": request.username,
                "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                "session_id": state["session_id"],
            },
        )
        LOG.debug(f"Data to flag: {data_to_flag}")
        # gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
    except Exception as e:
        LOG.exception(e)
        raise e
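

# Configuration comes from environment variables (e.g., a local .env when the
# dotenv lines above are uncommented). A hypothetical .env might contain:
#   OPENAI_API_KEY=...
#   ANTHROPIC_API_KEY=...
#   HF_TOKEN=...
#   CHAT_USERNAME=...
#   CHAT_PASSWORD=...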
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

theme = gr.themes.Soft()
creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
# gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
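
# UI: a Setup tab for choosing the model, case, and system prompt, and a Chatbot tab
# for the conversation itself, with a streaming chat window and a shareable-link slot.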
title = "AI Debate Partner"

with gr.Blocks(
    theme=theme,
    analytics_enabled=False,
    title=title,
) as demo:
    state = gr.State()
    gr.Markdown(f"### {title}")
    with gr.Tab("Setup"):
        with gr.Column():
            llm_input = gr.Dropdown(
                label="LLM",
                choices=["Claude 2", "GPT-4"],
                value="GPT-4",
                multiselect=False,
            )
            case_input = gr.Dropdown(
                label="Case",
                choices=list(CASES.keys()),
                value="Netflix",
                multiselect=False,
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt", value=SYSTEM_MESSAGE
            )
            update_system_button = gr.Button(value="Update Prompt & Reset")
            status_markdown = gr.Markdown()
    with gr.Tab("Chatbot"):
        with gr.Column():
            chatbot = gr.Chatbot(label="ChatBot")
            input_message = gr.Textbox(
                placeholder="Send a message.",
                label="Type an input and press Enter",
            )
            b1 = gr.Button(value="Submit")
            share_link = gr.Markdown()
    # gradio_flagger.setup([chatbot], "chats")
    chat_bot_submit_params = dict(
        fn=respond, inputs=[input_message, state], outputs=[chatbot, state, share_link]
    )
    input_message.submit(**chat_bot_submit_params)
    b1.click(**chat_bot_submit_params)
    update_system_button.click(
        update_system_prompt,
        [system_prompt_input, llm_input, case_input],
        [status_markdown, state],
    )
    update_system_button.click(reset_textbox, [], [input_message])
    update_system_button.click(reset_textbox, [], [chatbot])
    b1.click(reset_textbox, [], [input_message])
    input_message.submit(reset_textbox, [], [input_message])

demo.queue(max_size=99, concurrency_count=99, api_open=False).launch(
    debug=True, auth=auth
)