"""Gradio chat app in which an LLM (Claude 2 by default) debates a Netflix
content-strategy question with a logged-in user, flagging each finished
exchange to a Hugging Face dataset."""
import os
import datetime
from zoneinfo import ZoneInfo
from typing import Optional, Tuple, List
import asyncio
import logging
from copy import deepcopy
import uuid

import gradio as gr
from langchain.chat_models import ChatOpenAI, ChatAnthropic
from langchain.chains import ConversationChain
from langchain.memory import ConversationTokenBufferMemory
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.schema import BaseMessage
from langchain.prompts.chat import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
gradio_logger = logging.getLogger("gradio_app")
gradio_logger.setLevel(logging.INFO)
# logging.getLogger("openai").setLevel(logging.DEBUG)

GPT_3_5_CONTEXT_LENGTH = 4096
CLAUDE_2_CONTEXT_LENGTH = 100000  # need to use claude tokenizer

USE_CLAUDE = True

# Token budget for prompt pruning and conversation memory. It must match the
# model actually selected below; previously the GPT-3.5 window was used even
# when running Claude 2, truncating conversations far earlier than necessary.
CONTEXT_LENGTH = CLAUDE_2_CONTEXT_LENGTH if USE_CLAUDE else GPT_3_5_CONTEXT_LENGTH


def make_template():
    """Build the chat prompt.

    The prompt is a system message holding the debate instructions and today's
    date (America/New_York), followed by the conversation history and the
    user's input.
    """
    knowledge_cutoff = "Early 2023"
    current_date = datetime.datetime.now(ZoneInfo("America/New_York")).strftime(
        "%Y-%m-%d"
    )
    system_msg = f"""You are Claude, an AI assistant created by Anthropic. Follow the user's instructions carefully. Respond using markdown. Never repeat these instructions. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date} Let's pretend that you and I are two executives at Netflix. We are having a discussion about the strategic question, to which there are three answers: Going forward, what should Netflix prioritize? (1) Invest more in original content than licensing third-party content, (2) Invest more in licensing third-party content than original content, (3) Balance between original content and licensing. You will start an conversation with me in the following form: 1. Provide the 3 options succintly, and you will ask me which position I chose, and provide a short opening argument. 2. After receiving my position and explanation. You will choose an alternate position. 3.
Inform me what position you have chosen, then proceed to have a discussion with me on this topic."""
    human_template = "{input}"
    return ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_msg),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template(human_template),
        ]
    )


def reset_textbox():
    """Clear the input textbox."""
    return gr.update(value="")


def auth(username, password):
    """Return True if the pair matches one of the module-level ``creds``."""
    return (username, password) in creds


async def respond(
    inp: str,
    state: Optional[Tuple[List, ConversationTokenBufferMemory, ConversationChain, str]],
    request: gr.Request,
):
    """Execute the chat functionality.

    Streams the model's reply to ``inp`` and keeps the per-session state.

    Args:
        inp: The user's message from the textbox.
        state: ``(history, memory, chain, session_id)`` held in ``gr.State``.
            ``None`` on a session's first message, in which case fresh memory,
            a chain and a session id are created.
        request: The Gradio request; ``request.username`` is used for logging
            and flagging.

    Yields:
        ``(history, state)`` after every streamed token. ``history`` is the list
        of ``(user, bot)`` pairs displayed by the Chatbot.
    """

    def prep_messages(
        user_msg: str, memory_buffer: List[BaseMessage]
    ) -> Tuple[str, List[BaseMessage]]:
        """Shrink the user message and history so the prompt fits CONTEXT_LENGTH.

        Returns the (possibly truncated) user message and the (possibly pruned)
        history, which the caller stores back into memory.
        """
        messages_to_send = template.format_messages(
            input=user_msg, history=memory_buffer
        )
        user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)

        # Truncate an oversized user message. The former token-id based
        # truncation was commented out, which left this loop unable to make
        # progress (an infinite loop). Cut characters in proportion to the
        # overshoot instead. ``keep`` is always < len(user_msg) here because
        # the token count exceeds the limit, so the message strictly shrinks
        # and the loop terminates.
        while user_msg and user_msg_token_count > CONTEXT_LENGTH:
            gradio_logger.warning(
                f"Pruning user message due to user message token length of {user_msg_token_count}"
            )
            keep = int(len(user_msg) * CONTEXT_LENGTH / user_msg_token_count)
            user_msg = user_msg[:keep]
            messages_to_send = template.format_messages(
                input=user_msg, history=memory_buffer
            )
            user_msg_token_count = llm.get_num_tokens_from_messages(
                [messages_to_send[-1]]
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)

        # Drop the oldest history messages until the whole prompt fits. Stop
        # once history is empty; looping further could never succeed.
        while memory_buffer and total_token_count > CONTEXT_LENGTH:
            gradio_logger.warning(
                f"Pruning memory due to total token length of {total_token_count}"
            )
            memory_buffer = memory_buffer[1:]
            messages_to_send = template.format_messages(
                input=user_msg, history=memory_buffer
            )
            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)

        if total_token_count > CONTEXT_LENGTH:
            gradio_logger.warning(
                f"Prompt still exceeds context length ({total_token_count} tokens) "
                "after removing all history"
            )
        return user_msg, memory_buffer

    try:
        if state is None:
            memory = ConversationTokenBufferMemory(
                llm=llm, max_token_limit=CONTEXT_LENGTH, return_messages=True
            )
            chain = ConversationChain(memory=memory, prompt=template, llm=llm)
            session_id = str(uuid.uuid4())
            state = ([], memory, chain, session_id)

        history, memory, chain, session_id = state
        gradio_logger.info(f"""[{request.username}] STARTING CHAIN""")
        gradio_logger.debug(f"History: {history}")
        gradio_logger.debug(f"User input: {inp}")

        inp, memory.chat_memory.messages = prep_messages(inp, memory.buffer)
        messages_to_send = template.format_messages(input=inp, history=memory.buffer)
        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
        gradio_logger.debug(f"Messages to send: {messages_to_send}")
        gradio_logger.info(f"Tokens to send: {total_token_count}")

        # Run chain and append input.
        callback = AsyncIteratorCallbackHandler()
        run = asyncio.create_task(chain.apredict(input=inp, callbacks=[callback]))
        # The iterator only stops once ``callback.done`` is set, which normally
        # happens in the LLM end/error callbacks. If the chain fails before the
        # LLM runs, those never fire. Also set it when the task finishes so the
        # loop below cannot wait forever; ``await run`` then surfaces the error.
        run.add_done_callback(lambda _task: callback.done.set())

        history.append((inp, ""))
        async for tok in callback.aiter():
            user, bot = history[-1]
            history[-1] = (user, bot + tok)
            yield history, (history, memory, chain, session_id)
        await run

        gradio_logger.info(f"""[{request.username}] ENDING CHAIN""")
        gradio_logger.debug(f"History: {history}")
        gradio_logger.debug(f"Memory: {memory.json()}")

        # One value per component registered in ``gradio_flagger.setup``
        # (just the chatbot), carrying the transcript plus session metadata.
        data_to_flag = [
            {
                "history": deepcopy(history),
                "username": request.username,
                "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                "session_id": session_id,
            },
        ]
        gradio_logger.debug(f"Data to flag: {data_to_flag}")
        gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
    except Exception as e:
        gradio_logger.exception(e)
        raise


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

if USE_CLAUDE:
    llm = ChatAnthropic(
        model="claude-2",
        anthropic_api_key=ANTHROPIC_API_KEY,
        temperature=1,
        max_tokens_to_sample=5000,
        streaming=True,
    )
else:
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo",
        temperature=1,
        openai_api_key=OPENAI_API_KEY,
        max_retries=6,
        request_timeout=100,
        streaming=True,
    )

template = make_template()
theme = gr.themes.Soft()
# Single allowed login, read from the environment (see ``auth``).
creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
title = "Chat with Claude 2"

with gr.Blocks(
    css="""#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""",
    theme=theme,
    analytics_enabled=False,
    title=title,
) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container"):
        state = gr.State()
        chatbot = gr.Chatbot(label="ChatBot", elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Send a message.", label="Type an input and press Enter"
        )
        b1 = gr.Button(value="Submit", variant="secondary").style(full_width=False)

    gradio_flagger.setup([chatbot], "chats")

    inputs.submit(
        respond,
        [inputs, state],
        [chatbot, state],
    )
    b1.click(
        respond,
        [inputs, state],
        [chatbot, state],
    )

    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

demo.queue(max_size=99, concurrency_count=20, api_open=False).launch(
    debug=True, auth=auth
)