import os

import anthropic
import dotenv
import google.generativeai as genai
import streamlit as st
from openai import OpenAI
from together import Together

from constants import (
    LLM_COUNCIL_MEMBERS,
    PROVIDER_TO_AVATAR_MAP,
    AGGREGATORS,
)
from prompts import *
from judging_dataclasses import *
dotenv.load_dotenv()

PASSWORD = os.getenv("APP_PASSWORD")

# Load API keys from environment variables.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")

# Initialize API clients.
together_client = Together(api_key=TOGETHER_API_KEY)

genai.configure(api_key=GOOGLE_API_KEY)

openai_client = OpenAI(
    api_key=OPENAI_API_KEY,
    organization="org-kUoRSK0nOw4W2nQYMVGWOt03",
    project="proj_zb6k1DdgnSEbiAEMWxSOVVu4",
)

anthropic_client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
def anthropic_streamlit_streamer(stream):
    """
    Process the Anthropic streaming response and yield content from the deltas.

    :param stream: Streaming object from the Anthropic API.
    :return: Yields content (text) from the streaming response.
    """
    for event in stream:
        if hasattr(event, "type"):
            # Handle content block deltas.
            if event.type == "content_block_delta" and hasattr(event, "delta"):
                # Extract the text delta from the event.
                text_delta = getattr(event.delta, "text", None)
                if text_delta:
                    yield text_delta
            # Handle message completion events.
            elif event.type == "message_stop":
                break  # End of message; stop streaming.
def google_streamlit_streamer(stream):
    for chunk in stream:
        yield chunk.text


def together_streamlit_streamer(stream):
    for chunk in stream:
        # Some chunks (e.g., the final one) may carry no delta content.
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
def llm_streamlit_streamer(stream, llm):
    if llm.startswith("anthropic"):
        return anthropic_streamlit_streamer(stream)
    elif llm.startswith("vertex"):
        return google_streamlit_streamer(stream)
    elif llm.startswith("together"):
        return together_streamlit_streamer(stream)
    # OpenAI streams are understood natively by st.write_stream.
    return stream
# Helper functions for LLM council and aggregator selection.
def llm_council_selector():
    selected_council = st.radio(
        "Choose a council configuration", options=list(LLM_COUNCIL_MEMBERS.keys())
    )
    return LLM_COUNCIL_MEMBERS[selected_council]


def aggregator_selector():
    return st.radio("Choose an aggregator LLM", options=AGGREGATORS)
# API calls for the different providers.
def get_openai_response(model_name, prompt):
    return openai_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )


# https://docs.anthropic.com/en/api/messages-streaming
def get_anthropic_response(model_name, prompt):
    return anthropic_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}],
        model=model_name,
        stream=True,
    )


def get_together_response(model_name, prompt):
    return together_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )


# https://ai.google.dev/gemini-api/docs/text-generation?lang=python
def get_google_response(model_name, prompt):
    model = genai.GenerativeModel(model_name)
    return model.generate_content(prompt, stream=True)
def get_llm_response_stream(model_identifier, prompt):
    """Returns a streamlit-friendly stream of response tokens from the LLM."""
    provider, model_name = model_identifier.split("://")
    if provider == "openai":
        # st.write_stream understands OpenAI streams natively, so no wrapper
        # is needed here.
        return get_openai_response(model_name, prompt)
    elif provider == "anthropic":
        return anthropic_streamlit_streamer(get_anthropic_response(model_name, prompt))
    elif provider == "together":
        return together_streamlit_streamer(get_together_response(model_name, prompt))
    elif provider == "vertex":
        return google_streamlit_streamer(get_google_response(model_name, prompt))
    else:
        return None
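

# Hedged sketch, not wired into the UI: the Submit handler in main() fetches
# council responses sequentially and carries a TODO about making that
# asynchronous. The helper below is one minimal way to parallelize the calls
# with worker threads; it drains each provider stream to a complete string,
# so it trades token-by-token streaming for speed.
def fetch_all_responses_concurrently(models, prompt):
    from concurrent.futures import ThreadPoolExecutor  # local import: sketch only

    def fetch(model):
        pieces = []
        for chunk in get_llm_response_stream(model, prompt):
            if isinstance(chunk, str):
                # The wrapped providers yield plain text.
                pieces.append(chunk)
            elif chunk.choices and chunk.choices[0].delta.content:
                # Raw OpenAI streams yield chunk objects.
                pieces.append(chunk.choices[0].delta.content)
        return model, "".join(pieces)

    with ThreadPoolExecutor(max_workers=max(len(models), 1)) as executor:
        return dict(executor.map(fetch, models))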
def get_response_key(model):
    return model + ".response"


def get_model_from_response_key(response_key):
    # Model identifiers can themselves contain dots (e.g., version numbers in
    # Together model names), so strip only the trailing ".response" suffix.
    return response_key.removesuffix(".response")


def get_judging_key(judge_model, response_model):
    return "judge." + judge_model + "." + response_model


def get_aggregator_response_key(model):
    return model + ".aggregator_response"
# Streamlit form UI.
def render_criteria_form(criteria_num):
    """Render an input form for a single criterion."""
    with st.expander(f"Criteria {criteria_num + 1}"):
        name = st.text_input(f"Name for Criteria {criteria_num + 1}")
        description = st.text_area(f"Description for Criteria {criteria_num + 1}")
        min_score = st.number_input(
            f"Min Score for Criteria {criteria_num + 1}", min_value=0, step=1
        )
        max_score = st.number_input(
            f"Max Score for Criteria {criteria_num + 1}", min_value=0, step=1
        )
    return Criteria(
        name=name, description=description, min_score=min_score, max_score=max_score
    )
def get_response_mapping():
    # Inspect the session state for all of the responses. This is a dictionary
    # mapping model names to their responses. The aggregator response is also
    # included, under the key "<model>.aggregator_response".
    response_mapping = {}
    for key in st.session_state.keys():
        if key.endswith(".response"):
            response_mapping[get_model_from_response_key(key)] = st.session_state[key]
        if key.endswith(".aggregator_response"):
            response_mapping[key] = st.session_state[key]
    return response_mapping
def format_likert_comparison_options(options):
    return "\n".join([f"{i + 1}: {option}" for i, option in enumerate(options)])


def format_criteria_list(criteria_list):
    return "\n".join(
        [f"{criteria.name}: {criteria.description}" for criteria in criteria_list]
    )
def get_direct_assessment_prompt(
    direct_assessment_prompt, user_prompt, response, criteria_list, options
):
    return direct_assessment_prompt.format(
        user_prompt=user_prompt,
        response=response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )
def get_default_direct_assessment_prompt(user_prompt):
    return get_direct_assessment_prompt(
        DEFAULT_DIRECT_ASSESSMENT_PROMPT,
        user_prompt=user_prompt,
        response="{{response}}",
        criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
        options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
    )
def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
    responses_from_other_llms = "\n\n".join(
        [f"{model}: {st.session_state.get(get_response_key(model))}" for model in llms]
    )
    return aggregator_prompt.format(
        user_prompt=user_prompt,
        responses_from_other_llms=responses_from_other_llms,
    )


def get_default_aggregator_prompt(user_prompt, llms):
    return get_aggregator_prompt(
        DEFAULT_AGGREGATOR_PROMPT,
        user_prompt=user_prompt,
        llms=llms,
    )
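

# Hedged sketch, not yet wired into the UI: the judging TODO in main() calls
# for parsing each judge's free-text assessment into structured scores with
# gpt-4o-mini in JSON mode. The helper below is one minimal way to do that;
# the output schema named in the instruction string is an assumption here,
# not something prompts.py defines.
def parse_direct_assessment_sketch(judging_text):
    """Sketch: extract per-criterion scores from a judge's raw assessment."""
    import json  # local import: only needed by this sketch

    completion = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        response_format={"type": "json_object"},  # JSON mode
        messages=[
            {
                "role": "user",
                "content": (
                    "Extract the scores from the assessment below as a JSON "
                    'object of the form {"criterion_name": score, ...}.\n\n'
                    + judging_text
                ),
            }
        ],
    )
    return json.loads(completion.choices[0].message.content)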
# Main Streamlit app.
def main():
    st.set_page_config(
        page_title="Language Model Council Sandbox", page_icon="🏛️", layout="wide"
    )

    # Custom CSS for the chat display.
    center_css = """
    <style>
    h1, h2, h3, h6 { text-align: center; }
    .chat-container {
        display: flex;
        align-items: flex-start;
        margin-bottom: 10px;
    }
    .avatar {
        width: 50px;
        margin-right: 10px;
    }
    .message {
        background-color: #f1f1f1;
        padding: 10px;
        border-radius: 10px;
        width: 100%;
    }
    </style>
    """
    st.markdown(center_css, unsafe_allow_html=True)

    # App title and description.
    st.title("Language Model Council Sandbox")
    st.markdown("###### Invoke a council of LLMs to answer a prompt and judge each other's responses.")
    st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")
    # Authentication system.
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    cols = st.columns([2, 1, 2])
    if not st.session_state.authenticated:
        with cols[1]:
            password = st.text_input("Password", type="password")
            if st.button("Login", use_container_width=True):
                if password == PASSWORD:
                    st.session_state.authenticated = True
                else:
                    st.error("Invalid credentials")
    if st.session_state.authenticated:
        st.success("Logged in successfully!")

        # Council and aggregator selection.
        selected_models = llm_council_selector()
        st.write("Selected Models:", selected_models)
        selected_aggregator = aggregator_selector()

        # Prompt input.
        user_prompt = st.text_area("Enter your prompt:")
        if st.button("Submit"):
            st.write("Responses:")

            # Fetch and stream responses from each selected model.
            # TODO: Make this asynchronous? (See the
            # fetch_all_responses_concurrently sketch above.)
            for model in selected_models:
                with st.chat_message(
                    model,
                    avatar=PROVIDER_TO_AVATAR_MAP[model],
                ):
                    message_placeholder = st.empty()
                    stream = get_llm_response_stream(model, user_prompt)
                    if stream:
                        st.session_state[get_response_key(model)] = (
                            message_placeholder.write_stream(stream)
                        )
            # Get the aggregator prompt.
            aggregator_prompt = get_default_aggregator_prompt(
                user_prompt=user_prompt, llms=selected_models
            )
            with st.expander("Aggregator Prompt"):
                st.write(aggregator_prompt)

            # Fetch and stream the response from the aggregator.
            st.write(f"Mixture-of-Agents response from {selected_aggregator}:")
            with st.chat_message(
                selected_aggregator,
                avatar=PROVIDER_TO_AVATAR_MAP[selected_aggregator],
            ):
                message_placeholder = st.empty()
                aggregator_stream = get_llm_response_stream(
                    selected_aggregator, aggregator_prompt
                )
                # Consume the stream exactly once; write_stream returns the
                # full text, which is stored for later judging.
                if aggregator_stream:
                    st.session_state[
                        get_aggregator_response_key(selected_aggregator)
                    ] = message_placeholder.write_stream(aggregator_stream)
        # Judging.
        st.markdown("#### Judging Configuration Form")

        # Choose the type of assessment.
        assessment_type = st.radio(
            "Select the type of assessment",
            options=["Direct Assessment", "Pairwise Comparison"],
        )

        # Depending on the assessment type, render different forms.
        if assessment_type == "Direct Assessment":
            direct_assessment_prompt = st.text_area(
                "Prompt for the Direct Assessment",
                value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
                height=500,
            )
            # TODO: Add an option to edit the criteria list with a basic text
            # field (see the commented-out sketch below).
            criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
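            # Hedged sketch for the TODO above, left commented out so behavior
            # is unchanged: edit the criteria as plain "name: description"
            # lines and parse them back into Criteria objects (the 1-7 score
            # bounds are an assumption matching the seven-point options).
            # criteria_text = st.text_area(
            #     "Edit criteria (one 'name: description' per line)",
            #     value=format_criteria_list(criteria_list),
            # )
            # criteria_list = [
            #     Criteria(
            #         name=name.strip(),
            #         description=desc.strip(),
            #         min_score=1,
            #         max_score=7,
            #     )
            #     for name, _, desc in (
            #         line.partition(":") for line in criteria_text.splitlines()
            #     )
            #     if desc.strip()
            # ]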
            # Run the direct assessment when the form is submitted.
            if st.button("Submit Direct Assessment"):
                # Submit direct assessment.
                responses_for_judging = get_response_mapping()
                response_judging_columns = st.columns(3)
                # Map each judged model to a column index (three columns,
                # round-robin).
                responses_for_judging_to_streamlit_column_index_map = {
                    model: i % 3
                    for i, model in enumerate(responses_for_judging.keys())
                }

                # Get judging responses.
                for response_model, response in responses_for_judging.items():
                    st_column = response_judging_columns[
                        responses_for_judging_to_streamlit_column_index_map[
                            response_model
                        ]
                    ]
                    with st_column:
                        st.write(f"Judging {response_model}")
                        judging_prompt = get_direct_assessment_prompt(
                            direct_assessment_prompt,
                            user_prompt,
                            response,
                            criteria_list,
                            SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                        )
                        for judging_model in selected_models:
                            with st.expander("Detailed assessments", expanded=True):
                                with st.chat_message(
                                    judging_model,
                                    avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
                                ):
                                    st.write(f"Judge: {judging_model}")
                                    message_placeholder = st.empty()
                                    judging_stream = get_llm_response_stream(
                                        judging_model, judging_prompt
                                    )
                                    if judging_stream:
                                        st.session_state[
                                            get_judging_key(
                                                judging_model, response_model
                                            )
                                        ] = message_placeholder.write_stream(
                                            judging_stream
                                        )
                        # TODO: When all of the judging is finished for the
                        # given response, parse the raw assessments into
                        # structured scores with gpt-4o-mini in JSON mode (see
                        # the parse_direct_assessment_sketch helper above).
        elif assessment_type == "Pairwise Comparison":
            pairwise_comparison_prompt = st.text_area(
                "Prompt for the Pairwise Comparison"
            )
            granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
            ties_allowed = st.checkbox("Are ties allowed?")
            position_swapping = st.checkbox("Enable position swapping?")
            reference_model = st.text_input("Reference Model")

            # Create a PairwiseComparison config when the form is submitted.
            if st.button("Submit Pairwise Comparison"):
                pairwise_comparison_config = PairwiseComparison(
                    type="pairwise_comparison",
                    granularity=granularity,
                    ties_allowed=ties_allowed,
                    position_swapping=position_swapping,
                    reference_model=reference_model,
                    prompt=pairwise_comparison_prompt,
                )
                st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")

                # Submit pairwise comparison.
                responses_for_judging = get_response_mapping()
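
                # Hedged sketch: the pairwise judging loop itself is not
                # implemented yet. Enumerate the unordered response pairs each
                # judge would compare; with position swapping enabled, every
                # pair would also be judged in the reversed order.
                from itertools import combinations

                for model_a, model_b in combinations(responses_for_judging, 2):
                    st.write(f"Would compare: {model_a} vs. {model_b}")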
    else:
        with cols[1]:
            st.warning("Please log in to access this app.")


if __name__ == "__main__":
    main()