# sandbox / app.py
# justinxzhao's picture
# Some refactoring, judging responses for direct assessment.
# 577870e
# raw
# history blame
# 15.9 kB
import os
import streamlit as st
import dotenv
import openai
from openai import OpenAI
import anthropic
from together import Together
import google.generativeai as genai
import time
from typing import List, Optional, Literal, Union
from constants import (
LLM_COUNCIL_MEMBERS,
PROVIDER_TO_AVATAR_MAP,
AGGREGATORS,
)
from prompts import *
from judging_dataclasses import *
# Populate the process environment (e.g. from a local .env file) before the
# os.getenv() lookups below; the order of these statements matters.
dotenv.load_dotenv()
# Password compared against the login field in main(). os.getenv returns None
# when APP_PASSWORD is not set.
PASSWORD = os.getenv("APP_PASSWORD")
# Load API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
# Initialize API clients
together_client = Together(api_key=TOGETHER_API_KEY)
genai.configure(api_key=GOOGLE_API_KEY)
# Set up API clients for OpenAI and Anthropic
openai.api_key = OPENAI_API_KEY
# NOTE(review): no api_key is passed to OpenAI() here; presumably the client
# picks up OPENAI_API_KEY from the environment on its own - confirm.
openai_client = OpenAI(
organization="org-kUoRSK0nOw4W2nQYMVGWOt03",
project="proj_zb6k1DdgnSEbiAEMWxSOVVu4",
)
# Earlier form that passed the key explicitly, kept for reference:
# anthropic_client = anthropic.Client(api_key=ANTHROPIC_API_KEY)
# NOTE(review): ANTHROPIC_API_KEY is loaded above but not passed here;
# presumably the client reads it from the environment - confirm.
anthropic_client = anthropic.Anthropic()
def anthropic_streamlit_streamer(stream):
    """
    Turn an Anthropic streaming response into a generator of text fragments.

    Events without a ``type`` attribute are ignored. ``content_block_delta``
    events contribute their ``delta.text`` when it is present and non-empty.
    A ``message_stop`` event ends the generator; later events are not read.

    :param stream: Iterable of streaming events from the Anthropic API.
    :return: Yields the non-empty text deltas in arrival order.
    """
    for event in stream:
        if not hasattr(event, "type"):
            continue

        kind = event.type
        if kind == "message_stop":
            # End of message: nothing after this point is consumed.
            return
        if kind != "content_block_delta" or not hasattr(event, "delta"):
            continue

        fragment = getattr(event.delta, "text", None)
        if fragment:
            yield fragment
def google_streamlit_streamer(stream):
    """Yield the ``text`` attribute of each chunk in *stream*, in order."""
    yield from (piece.text for piece in stream)
def together_streamlit_streamer(stream):
    """
    Yield the text content of each chat-completion chunk in *stream*.

    Chunks that carry no text are skipped instead of being forwarded: a chunk
    with an empty ``choices`` list would otherwise raise ``IndexError``, and a
    chunk whose ``delta.content`` is ``None`` (e.g. a role-only or final
    chunk) would otherwise yield ``None`` to the consumer.

    :param stream: Iterable of chunks exposing ``choices[0].delta.content``.
    :return: Yields the non-empty content strings in arrival order.
    """
    for chunk in stream:
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        delta = getattr(choices[0], "delta", None)
        content = getattr(delta, "content", None)
        if content:
            yield content
def llm_streamlit_streamer(stream, llm):
    """
    Wrap *stream* with the text streamer matching the prefix of *llm*.

    :param stream: Raw streaming response from a provider.
    :param llm: Model identifier; its prefix ("anthropic", "vertex" or
        "together") selects the streamer.
    :return: A generator of text fragments, or ``None`` when the prefix
        matches none of the known providers.
    """
    if llm.startswith("anthropic"):
        return anthropic_streamlit_streamer(stream)
    if llm.startswith("vertex"):
        return google_streamlit_streamer(stream)
    if llm.startswith("together"):
        return together_streamlit_streamer(stream)
    return None
# Helper functions for LLM council and aggregator selection
def llm_council_selector():
    """
    Render a radio button listing the council configurations.

    :return: The ``LLM_COUNCIL_MEMBERS`` entry for the chosen configuration.
    """
    council_names = list(LLM_COUNCIL_MEMBERS.keys())
    chosen_name = st.radio("Choose a council configuration", options=council_names)
    return LLM_COUNCIL_MEMBERS[chosen_name]
def aggregator_selector():
    """
    Render a radio button over ``AGGREGATORS``.

    :return: The value selected in the radio widget.
    """
    chosen_aggregator = st.radio("Choose an aggregator LLM", options=AGGREGATORS)
    return chosen_aggregator
# API calls for different providers
def get_openai_response(model_name, prompt):
    """
    Start a streaming chat completion on the module-level OpenAI client.

    :param model_name: Model name passed to the API as ``model``.
    :param prompt: Text sent as a single user message.
    :return: Whatever ``chat.completions.create`` returns for ``stream=True``.
    """
    conversation = [{"role": "user", "content": prompt}]
    return openai_client.chat.completions.create(
        messages=conversation,
        model=model_name,
        stream=True,
    )
# https://docs.anthropic.com/en/api/messages-streaming
def get_anthropic_response(model_name, prompt):
    """
    Start a streaming message request on the module-level Anthropic client.

    :param model_name: Model name passed to the API as ``model``.
    :param prompt: Text sent as a single user message.
    :return: Whatever ``messages.create`` returns for ``stream=True``; the
        request is made with ``max_tokens=1024``.
    """
    conversation = [{"role": "user", "content": prompt}]
    return anthropic_client.messages.create(
        model=model_name,
        messages=conversation,
        max_tokens=1024,
        stream=True,
    )
def get_together_response(model_name, prompt):
    """
    Start a streaming chat completion on the module-level Together client.

    :param model_name: Model name passed to the API as ``model``.
    :param prompt: Text sent as a single user message.
    :return: Whatever ``chat.completions.create`` returns for ``stream=True``.
    """
    conversation = [{"role": "user", "content": prompt}]
    return together_client.chat.completions.create(
        messages=conversation,
        model=model_name,
        stream=True,
    )
# https://ai.google.dev/gemini-api/docs/text-generation?lang=python
def get_google_response(model_name, prompt):
    """
    Start a streaming generation with a ``genai.GenerativeModel``.

    :param model_name: Name used to construct the ``GenerativeModel``.
    :param prompt: Content passed to ``generate_content``.
    :return: Whatever ``generate_content`` returns for ``stream=True``.
    """
    return genai.GenerativeModel(model_name).generate_content(prompt, stream=True)
def get_llm_response_stream(model_identifier, prompt):
    """
    Return a streamlit-friendly stream of response tokens from the LLM.

    :param model_identifier: String of the form ``"<provider>://<model>"``.
        Only the first ``"://"`` separates provider from model name.
    :param prompt: Prompt text forwarded to the provider.
    :return: The raw OpenAI stream for ``openai``; a generator of text
        fragments for ``anthropic``, ``together`` and ``vertex``; ``None``
        for an unknown provider or an identifier without ``"://"``.
    """
    # partition() never raises on malformed identifiers, unlike unpacking the
    # result of split("://"); malformed input is reported as None, the same
    # way an unknown provider is.
    provider, separator, model_name = model_identifier.partition("://")
    if not separator:
        return None

    if provider == "openai":
        return get_openai_response(model_name, prompt)
    if provider == "anthropic":
        return anthropic_streamlit_streamer(get_anthropic_response(model_name, prompt))
    if provider == "together":
        return together_streamlit_streamer(get_together_response(model_name, prompt))
    if provider == "vertex":
        return google_streamlit_streamer(get_google_response(model_name, prompt))
    return None
def get_response_key(model):
    """Return *model* with ``".response"`` appended."""
    suffix = ".response"
    return model + suffix
def get_model_from_response_key(response_key):
    """
    Recover the model identifier from a ``"<model>.response"`` key.

    Only the trailing ``".response"`` suffix is removed, so model names that
    themselves contain dots (e.g. ``gemini-1.5-flash``) are returned intact.
    Splitting on the first dot would truncate them.

    :param response_key: Key of the form ``"<model>.response"``.
    :return: The model identifier; *response_key* unchanged if it does not
        end with ``".response"``.
    """
    suffix = ".response"
    if response_key.endswith(suffix):
        return response_key[: -len(suffix)]
    return response_key
def get_judging_key(judge_model, response_model):
    """Return ``"judge.<judge_model>.<response_model>"``."""
    return ".".join(("judge", judge_model, response_model))
def get_aggregator_response_key(model):
    """Return *model* with ``".aggregator_response"`` appended."""
    suffix = ".aggregator_response"
    return model + suffix
# Streamlit form UI
def render_criteria_form(criteria_num):
    """
    Render the input widgets for one criterion inside an expander.

    :param criteria_num: Zero-based index; labels show ``criteria_num + 1``.
    :return: A ``Criteria`` built from the current widget values.
    """
    label_number = criteria_num + 1
    with st.expander(f"Criteria {label_number}"):
        entered_name = st.text_input(f"Name for Criteria {label_number}")
        entered_description = st.text_area(f"Description for Criteria {label_number}")
        lowest = st.number_input(
            f"Min Score for Criteria {label_number}", min_value=0, step=1
        )
        highest = st.number_input(
            f"Max Score for Criteria {label_number}", min_value=0, step=1
        )
        return Criteria(
            name=entered_name,
            description=entered_description,
            min_score=lowest,
            max_score=highest,
        )
def get_response_mapping():
    """
    Collect every stored response from ``st.session_state``.

    Keys ending in ``".response"`` are stored under the model identifier
    returned by ``get_model_from_response_key``. Keys ending in
    ``".aggregator_response"`` are stored under the full, unchanged key.

    :return: Dict mapping those keys to the stored session-state values.
    """
    collected = {}
    for state_key in st.session_state.keys():
        if state_key.endswith(".response"):
            model = get_model_from_response_key(state_key)
            collected[model] = st.session_state[state_key]
        if state_key.endswith(".aggregator_response"):
            collected[state_key] = st.session_state[state_key]
    return collected
def format_likert_comparison_options(options):
    """
    Number *options* from 1 and join them one per line.

    :param options: Iterable of option labels.
    :return: Lines of the form ``"<n>: <option>"`` joined with newlines.
    """
    numbered = (
        f"{position}: {option}" for position, option in enumerate(options, start=1)
    )
    return "\n".join(numbered)
def format_criteria_list(criteria_list):
    """
    Render each criterion as ``"<name>: <description>"``, one per line.

    :param criteria_list: Iterable of objects with ``name`` and
        ``description`` attributes.
    :return: The rendered lines joined with newlines.
    """
    rendered = []
    for criterion in criteria_list:
        rendered.append(f"{criterion.name}: {criterion.description}")
    return "\n".join(rendered)
def get_direct_assessment_prompt(
    direct_assessment_prompt, user_prompt, response, criteria_list, options
):
    """
    Fill a direct-assessment prompt template.

    The template is formatted with ``str.format`` using the fields
    ``user_prompt``, ``response``, ``criteria_list`` and ``options``.

    :param direct_assessment_prompt: Template string passed to ``str.format``.
    :param user_prompt: The user's original prompt.
    :param response: The response being judged.
    :param criteria_list: Criteria rendered through ``format_criteria_list``.
    :param options: Options rendered through
        ``format_likert_comparison_options``.
    :return: The formatted prompt string.
    """
    # Use the caller-supplied criteria and options; previously both arguments
    # were ignored in favour of the module-level defaults.
    return direct_assessment_prompt.format(
        user_prompt=user_prompt,
        response=response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )
def get_default_direct_assessment_prompt(user_prompt):
    """
    Build the default direct-assessment prompt with the response left open.

    The default template, criteria and options are filled in now, while the
    response slot is kept as a ``{response}`` placeholder. main() feeds the
    result back through ``get_direct_assessment_prompt`` once per judged
    response, and that second ``str.format`` pass fills the placeholder.

    :param user_prompt: The user's original prompt.
    :return: The partially formatted prompt, still containing ``{response}``.
    """
    # The substituted value must be "{response}", not "{{response}}": the
    # second format pass would turn doubled braces into the literal text
    # "{response}" and never insert the judged response.
    # NOTE(review): any other braces in user_prompt will also be interpreted
    # by that second format pass.
    return get_direct_assessment_prompt(
        DEFAULT_DIRECT_ASSESSMENT_PROMPT,
        user_prompt=user_prompt,
        response="{response}",
        criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
        options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
    )
def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
    """
    Fill an aggregator prompt template with the stored council responses.

    :param aggregator_prompt: Template passed to ``str.format`` with the
        fields ``user_prompt`` and ``responses_from_other_llms``.
    :param user_prompt: The user's original prompt.
    :param llms: Models whose responses are read from ``st.session_state``
        (``None`` is rendered for a model with no stored response).
    :return: The formatted prompt string.
    """
    sections = []
    for model in llms:
        stored_response = st.session_state.get(get_response_key(model))
        sections.append(f"{model}: {stored_response}")
    return aggregator_prompt.format(
        user_prompt=user_prompt,
        responses_from_other_llms="\n\n".join(sections),
    )
def get_default_aggregator_prompt(user_prompt, llms):
    """
    Build the aggregator prompt from ``DEFAULT_AGGREGATOR_PROMPT``.

    :param user_prompt: The user's original prompt.
    :param llms: Models whose stored responses are included.
    :return: The formatted prompt string.
    """
    template = DEFAULT_AGGREGATOR_PROMPT
    return get_aggregator_prompt(template, user_prompt=user_prompt, llms=llms)
# Main Streamlit App
def main():
    """
    Render the Language Model Council Sandbox page.

    Flow of a single script run:

    1. Configure the page, inject the chat CSS and show the title.
    2. Gate everything behind a password compared against ``PASSWORD``.
    3. When authenticated, let the user pick a council and an aggregator and
       enter a prompt. On "Submit", stream each council member's response and
       then the aggregator's response, storing each in ``st.session_state``.
    4. Render the judging configuration form. "Submit Direct Assessment" has
       every council member judge every stored response; "Submit Pairwise
       Comparison" builds a ``PairwiseComparison`` config.
    """
    st.set_page_config(
        page_title="Language Model Council Sandbox", page_icon="🏛️", layout="wide"
    )

    # Custom CSS for the chat display
    center_css = """
<style>
h1, h2, h3, h6 { text-align: center; }
.chat-container {
display: flex;
align-items: flex-start;
margin-bottom: 10px;
}
.avatar {
width: 50px;
margin-right: 10px;
}
.message {
background-color: #f1f1f1;
padding: 10px;
border-radius: 10px;
width: 100%;
}
</style>
"""
    st.markdown(center_css, unsafe_allow_html=True)

    # App title and description
    st.title("Language Model Council Sandbox")
    st.markdown("###### Invoke a council of LLMs to generate and judge each other.")
    st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")

    # Authentication system
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    cols = st.columns([2, 1, 2])

    if not st.session_state.authenticated:
        with cols[1]:
            password = st.text_input("Password", type="password")
            if st.button("Login", use_container_width=True):
                if password == PASSWORD:
                    st.session_state.authenticated = True
                else:
                    st.error("Invalid credentials")

    if st.session_state.authenticated:
        st.success("Logged in successfully!")

        # Council and aggregator selection
        selected_models = llm_council_selector()
        st.write("Selected Models:", selected_models)
        selected_aggregator = aggregator_selector()

        # Prompt input
        user_prompt = st.text_area("Enter your prompt:")

        if st.button("Submit"):
            st.write("Responses:")

            # Fetching and streaming responses from each selected model
            # TODO: Make this asynchronous?
            for model in selected_models:
                with st.chat_message(
                    model,
                    avatar=PROVIDER_TO_AVATAR_MAP[model],
                ):
                    message_placeholder = st.empty()
                    stream = get_llm_response_stream(model, user_prompt)
                    if stream:
                        st.session_state[get_response_key(model)] = (
                            message_placeholder.write_stream(stream)
                        )

            # Get the aggregator prompt.
            aggregator_prompt = get_default_aggregator_prompt(
                user_prompt=user_prompt, llms=selected_models
            )
            with st.expander("Aggregator Prompt"):
                st.write(aggregator_prompt)

            # Fetching and streaming response from the aggregator
            st.write(f"Mixture-of-Agents response from {selected_aggregator}:")
            with st.chat_message(
                selected_aggregator,
                avatar=PROVIDER_TO_AVATAR_MAP[selected_aggregator],
            ):
                message_placeholder = st.empty()
                aggregator_stream = get_llm_response_stream(
                    selected_aggregator, aggregator_prompt
                )
                if aggregator_stream:
                    # Consume the stream exactly once. Writing it a second
                    # time would read an already-exhausted stream and store
                    # an empty result.
                    st.session_state[
                        get_aggregator_response_key(selected_aggregator)
                    ] = message_placeholder.write_stream(aggregator_stream)

        # Judging.
        # Rendered outside the "Submit" handler: interacting with any widget
        # below reruns the script, and on that rerun st.button("Submit") is
        # False, so the judging form works from the responses persisted in
        # st.session_state.
        st.markdown("#### Judging Configuration Form")

        # Choose the type of assessment
        assessment_type = st.radio(
            "Select the type of assessment",
            options=["Direct Assessment", "Pairwise Comparison"],
        )

        # Depending on the assessment type, render different forms
        if assessment_type == "Direct Assessment":
            direct_assessment_prompt = st.text_area(
                "Prompt for the Direct Assessment",
                value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
                height=500,
            )

            # TODO: Add option to edit criteria list with a basic text field.
            criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST

            if st.button("Submit Direct Assessment"):
                # Submit direct assessment.
                responses_for_judging = get_response_mapping()

                response_judging_columns = st.columns(3)

                # Get judging responses, laying the judged responses out
                # round-robin over the three columns.
                for index, (response_model, response) in enumerate(
                    responses_for_judging.items()
                ):
                    st_column = response_judging_columns[index % 3]

                    with st_column:
                        st.write(f"Judging {response_model}")
                        judging_prompt = get_direct_assessment_prompt(
                            direct_assessment_prompt,
                            user_prompt,
                            response,
                            criteria_list,
                            SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                        )

                        for judging_model in selected_models:
                            with st.expander("Detailed assessments", expanded=True):
                                with st.chat_message(
                                    judging_model,
                                    avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
                                ):
                                    st.write(f"Judge: {judging_model}")
                                    message_placeholder = st.empty()
                                    judging_stream = get_llm_response_stream(
                                        judging_model, judging_prompt
                                    )
                                    if judging_stream:
                                        st.session_state[
                                            get_judging_key(
                                                judging_model, response_model
                                            )
                                        ] = message_placeholder.write_stream(
                                            judging_stream
                                        )
                        # When all of the judging is finished for the given response, get the actual
                        # values, parsed (use gpt-4o-mini for now) with json mode.
                        # TODO.

        elif assessment_type == "Pairwise Comparison":
            pairwise_comparison_prompt = st.text_area(
                "Prompt for the Pairwise Comparison"
            )
            granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
            ties_allowed = st.checkbox("Are ties allowed?")
            position_swapping = st.checkbox("Enable position swapping?")
            reference_model = st.text_input("Reference Model")

            # Create PairwiseComparison object when form is submitted
            if st.button("Submit Pairwise Comparison"):
                pairwise_comparison_config = PairwiseComparison(
                    type="pairwise_comparison",
                    granularity=granularity,
                    ties_allowed=ties_allowed,
                    position_swapping=position_swapping,
                    reference_model=reference_model,
                    # The text entered above; the bare name `prompt` is not
                    # defined anywhere in this function.
                    prompt=pairwise_comparison_prompt,
                )
                st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")

                # Submit pairwise comparison.
                responses_for_judging = get_response_mapping()
    else:
        with cols[1]:
            st.warning("Please log in to access this app.")
# Entry point: call main() only when this file is run as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()