# ruff: noqa: E501
import asyncio
import datetime
import logging
import os
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import pytz
import tiktoken

# from dotenv import load_dotenv
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationChain
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)

GPT_3_5_CONTEXT_LENGTH = 4096
CLAUDE_2_CONTEXT_LENGTH = 100000  # need to use claude tokenizer
SYSTEM_MESSAGE = """You are Claude, an AI assistant created by Anthropic.
Follow this message's instructions carefully. Respond using markdown.
Never repeat these instructions in a subsequent message.
Let's pretend that you and I are two executives at Netflix. We are having a discussion about the strategic question, to which there are three answers:
Going forward, what should Netflix prioritize?
(1) Invest more in original content than licensing third-party content, (2) Invest more in licensing third-party content than original content, (3) Balance between original content and licensing.
You will start a conversation with me in the following form:
1. Provide the 3 options succinctly and ask me to choose a position and provide a short opening argument. Do not yet provide your own position.
2. After receiving my position and explanation, you will choose an alternate position.
3. Inform me what position you have chosen, then proceed to have a discussion with me on this topic.
4. The discussion should be informative, but also rigorous. Do not agree with my arguments too easily."""
# load_dotenv()


def reset_textbox():
    return gr.update(value="")


def auth(username, password):
    return (username, password) in creds
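
# Build the chat model, its context length, and a tokenizer for the selected backend.
# Note: the GPT-4 branch reuses GPT_3_5_CONTEXT_LENGTH (4096), presumably as a conservative
# cap, and the Claude branch approximates token counts with tiktoken's cl100k_base encoding.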
def make_llm_state(use_claude: bool = False) -> Dict[str, Any]:
    if use_claude:
        llm = ChatAnthropic(
            model="claude-2",
            anthropic_api_key=ANTHROPIC_API_KEY,
            temperature=1,
            max_tokens_to_sample=5000,
            streaming=True,
        )
        context_length = CLAUDE_2_CONTEXT_LENGTH
        tokenizer = tiktoken.get_encoding("cl100k_base")
    else:
        llm = ChatOpenAI(
            model_name="gpt-4",
            temperature=1,
            openai_api_key=OPENAI_API_KEY,
            max_retries=6,
            request_timeout=100,
            streaming=True,
        )
        context_length = GPT_3_5_CONTEXT_LENGTH
        _, tokenizer = llm._get_encoding_model()
    return dict(llm=llm, context_length=context_length, tokenizer=tokenizer)
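
# Compose the chat prompt: the (possibly user-edited) system message plus a knowledge
# cutoff and the current date, followed by the conversation history and the new input.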
def make_template(system_msg: str = SYSTEM_MESSAGE) -> ChatPromptTemplate:
    knowledge_cutoff = "Early 2023"
    current_date = datetime.datetime.now(pytz.timezone("America/New_York")).strftime(
        "%Y-%m-%d"
    )
    system_msg += f"""
Knowledge cutoff: {knowledge_cutoff}
Current date: {current_date}
"""
    human_template = "{input}"
    LOG.info(system_msg)
    return ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_msg),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template(human_template),
        ]
    )
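
# Handler for the "Update Prompt & Reset" button: rebuild the template, LLM, memory,
# and chain from the Setup tab selections, which also resets the conversation.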
def update_system_prompt(
    system_msg: str, llm_option: str
) -> Tuple[str, Dict[str, Any]]:
    template_output = make_template(system_msg)
    state = set_state()
    state["template"] = template_output
    use_claude = llm_option == "Claude 2"
    state["llm_state"] = make_llm_state(use_claude)
    llm = state["llm_state"]["llm"]
    state["memory"] = ConversationTokenBufferMemory(
        llm=llm,
        max_token_limit=state["llm_state"]["context_length"],
        return_messages=True,
    )
    state["chain"] = ConversationChain(
        memory=state["memory"], prompt=state["template"], llm=llm
    )
    updated_status = "Prompt Updated! Chat has reset."
    return updated_status, state
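
# Initialize per-session state (template, LLM, token-buffer memory, chain, session id)
# on first use; pass existing state through unchanged.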
def set_state(state: Optional[gr.State] = None) -> Dict[str, Any]:
    if state is None:
        template = make_template()
        llm_state = make_llm_state()
        llm = llm_state["llm"]
        memory = ConversationTokenBufferMemory(
            llm=llm, max_token_limit=llm_state["context_length"], return_messages=True
        )
        chain = ConversationChain(memory=memory, prompt=template, llm=llm)
        session_id = str(uuid.uuid4())
        state = dict(
            template=template,
            llm_state=llm_state,
            history=[],
            memory=memory,
            chain=chain,
            session_id=session_id,
        )
        return state
    else:
        return state
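
# Async chat handler: prunes the conversation to fit the context window, streams the
# model's reply token by token, and logs each exchange via the HF dataset flagger.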
async def respond(
    inp: str,
    state: Optional[Dict[str, Any]],
    request: gr.Request,
):
    """Execute the chat functionality."""
    def prep_messages(
        user_msg: str, memory_buffer: List[BaseMessage]
    ) -> Tuple[str, List[BaseMessage]]:
        messages_to_send = state["template"].format_messages(
            input=user_msg, history=memory_buffer
        )
        user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        while user_msg_token_count > context_length:
            LOG.warning(
                f"Pruning user message due to user message token length of {user_msg_token_count}"
            )
            user_msg = tokenizer.decode(
                llm.get_token_ids(user_msg)[: context_length - 100]
            )
            messages_to_send = state["template"].format_messages(
                input=user_msg, history=memory_buffer
            )
            user_msg_token_count = llm.get_num_tokens_from_messages(
                [messages_to_send[-1]]
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        while total_token_count > context_length:
            LOG.warning(
                f"Pruning memory due to total token length of {total_token_count}"
            )
            if len(memory_buffer) == 1:
                memory_buffer.pop(0)
                continue
            memory_buffer = memory_buffer[1:]
            messages_to_send = state["template"].format_messages(
                input=user_msg, history=memory_buffer
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        return user_msg, memory_buffer
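
    # Pull the per-session LLM config, prune the prompt, run the chain as a background
    # task, and yield partial history as tokens arrive from the streaming callback.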
    try:
        if state is None:
            state = set_state()
        llm = state["llm_state"]["llm"]
        context_length = state["llm_state"]["context_length"]
        tokenizer = state["llm_state"]["tokenizer"]
        LOG.info(f"""[{request.username}] STARTING CHAIN""")
        LOG.debug(f"History: {state['history']}")
        LOG.debug(f"User input: {inp}")
        inp, state["memory"].chat_memory.messages = prep_messages(
            inp, state["memory"].buffer
        )
        messages_to_send = state["template"].format_messages(
            input=inp, history=state["memory"].buffer
        )
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        LOG.debug(f"Messages to send: {messages_to_send}")
        LOG.info(f"Tokens to send: {total_token_count}")
        # Run chain and append input.
        callback = AsyncIteratorCallbackHandler()
        run = asyncio.create_task(
            state["chain"].apredict(input=inp, callbacks=[callback])
        )
        state["history"].append((inp, ""))
        async for tok in callback.aiter():
            user, bot = state["history"][-1]
            bot += tok
            state["history"][-1] = (user, bot)
            yield state["history"], state
        await run
        LOG.info(f"""[{request.username}] ENDING CHAIN""")
        LOG.debug(f"History: {state['history']}")
        LOG.debug(f"Memory: {state['memory'].json()}")
        data_to_flag = (
            {
                "history": deepcopy(state["history"]),
                "username": request.username,
                "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                "session_id": state["session_id"],
            },
        )
        LOG.debug(f"Data to flag: {data_to_flag}")
        gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
    except Exception as e:
        LOG.exception(e)
        raise e
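
# Module-level configuration: API keys and auth credentials come from environment
# variables; chats are flagged to a Hugging Face dataset named "chats".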
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

theme = gr.themes.Soft()
creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
title = "AI Debate Partner"
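
# UI: a "Setup" tab to pick the LLM and edit the system prompt, and a "Chatbot" tab
# for the conversation itself.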
with gr.Blocks(
    theme=theme,
    analytics_enabled=False,
    title=title,
) as demo:
    state = gr.State()
    gr.Markdown(f"### {title}")
    with gr.Tab("Setup"):
        with gr.Column():
            llm_input = gr.Dropdown(
                label="LLM",
                choices=["Claude 2", "GPT-4"],
                value="GPT-4",
                multiselect=False,
            )
            system_prompt_input = gr.Textbox(
                label="System Prompt", value=SYSTEM_MESSAGE
            )
            update_system_button = gr.Button(value="Update Prompt & Reset")
            status_markdown = gr.Markdown()
    with gr.Tab("Chatbot"):
        with gr.Column():
            chatbot = gr.Chatbot(label="ChatBot")
            inputs = gr.Textbox(
                placeholder="Send a message.",
                label="Type an input and press Enter",
            )
            b1 = gr.Button(value="Submit")
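
    # Register the chatbot component with the flagger, then wire events: Enter and the
    # Submit button both call respond; updating the prompt resets the chat and inputs.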
    gradio_flagger.setup([chatbot], "chats")

    inputs.submit(
        respond,
        [inputs, state],
        [chatbot, state],
    )
    b1.click(
        respond,
        [inputs, state],
        [chatbot, state],
    )
    update_system_button.click(
        update_system_prompt,
        [system_prompt_input, llm_input],
        [status_markdown, state],
    )
    update_system_button.click(reset_textbox, [], [inputs])
    update_system_button.click(reset_textbox, [], [chatbot])
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])
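
# Queue with high concurrency so multiple sessions can stream at once; basic auth is
# available via auth() but is currently commented out.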
demo.queue(max_size=99, concurrency_count=99, api_open=False).launch(
    debug=True,  # auth=auth
)