justinxzhao committed
Commit 3e0f8f8 · 1 Parent(s): 577870e

Added per-response plots.

Files changed (4):
  1. app.py +215 -49
  2. constants.py +10 -0
  3. judging_dataclasses.py +15 -0
  4. prompts.py +15 -0
app.py CHANGED
@@ -7,15 +7,18 @@ import anthropic
 from together import Together
 import google.generativeai as genai
 import time
-from typing import List, Optional, Literal, Union
+from typing import List, Optional, Literal, Union, Dict
 from constants import (
     LLM_COUNCIL_MEMBERS,
     PROVIDER_TO_AVATAR_MAP,
     AGGREGATORS,
+    LLM_TO_UI_NAME_MAP,
 )
 from prompts import *
-from judging_dataclasses import *
-
+from judging_dataclasses import DirectAssessmentJudgingResponse
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
 
 dotenv.load_dotenv()
 
@@ -40,6 +43,8 @@ openai_client = OpenAI(
 # anthropic_client = anthropic.Client(api_key=ANTHROPIC_API_KEY)
 anthropic_client = anthropic.Anthropic()
 
+client = OpenAI()
+
 
 def anthropic_streamlit_streamer(stream):
     """
@@ -142,19 +147,43 @@ def get_llm_response_stream(model_identifier, prompt):
 
 
 def get_response_key(model):
-    return model + ".response"
+    return model + "__response"
 
 
 def get_model_from_response_key(response_key):
-    return response_key.split(".")[0]
+    return response_key.split("__")[0]
 
 
-def get_judging_key(judge_model, response_model):
-    return "judge." + judge_model + "." + response_model
+def get_direct_assessment_judging_key(judge_model, response_model):
+    return "direct_assessment_judge__" + judge_model + "__" + response_model
 
 
 def get_aggregator_response_key(model):
-    return model + ".aggregator_response"
+    return model + "__aggregator_response"
+
+
+def create_dataframe_for_direct_assessment_judging_response(
+    response: DirectAssessmentJudgingResponse,
+):
+    # Initialize empty list to collect data
+    data = []
+
+    # Loop through models
+    for judging_model in response.judging_models:
+        model_name = judging_model.model
+        # Loop through criteria_scores
+        for criteria_score in judging_model.criteria_scores:
+            data.append(
+                {
+                    "llm_judge_model": model_name,
+                    "criteria": criteria_score.criterion,
+                    "score": criteria_score.score,
+                    "explanation": criteria_score.explanation,
+                }
+            )
+
+    # Create DataFrame
+    return pd.DataFrame(data)
 
 
 # Streamlit form UI
@@ -177,12 +206,14 @@ def render_criteria_form(criteria_num):
 def get_response_mapping():
     # Inspect the session state for all the responses.
     # This is a dictionary mapping model names to their responses.
-    # The aggregator response is also included in this mapping under the key "<model>.aggregator_response".
+    # The aggregator response is also included in this mapping under the key "<model>__aggregator_response".
     response_mapping = {}
     for key in st.session_state.keys():
-        if key.endswith(".response"):
+        if "judge" in key:
+            continue
+        if key.endswith("__response"):
             response_mapping[get_model_from_response_key(key)] = st.session_state[key]
-        if key.endswith(".aggregator_response"):
+        if key.endswith("__aggregator_response"):
             response_mapping[key] = st.session_state[key]
     return response_mapping
 
@@ -210,9 +241,9 @@ def get_direct_assessment_prompt(
 
 def get_default_direct_assessment_prompt(user_prompt):
     return get_direct_assessment_prompt(
-        DEFAULT_DIRECT_ASSESSMENT_PROMPT,
+        direct_assessment_prompt=DEFAULT_DIRECT_ASSESSMENT_PROMPT,
         user_prompt=user_prompt,
-        response="{{response}}",
+        response="{response}",
         criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
         options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
     )
@@ -220,7 +251,10 @@ def get_default_direct_assessment_prompt(user_prompt):
 
 def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
     responses_from_other_llms = "\n\n".join(
-        [f"{model}: {st.session_state.get(get_response_key(model))}" for model in llms]
+        [
+            f"{get_ui_friendly_name(model)} START\n{st.session_state.get(get_response_key(model))}\n\n{get_ui_friendly_name(model)} END\n\n\n"
+            for model in llms
+        ]
     )
     return aggregator_prompt.format(
         user_prompt=user_prompt,
@@ -236,6 +270,100 @@ def get_default_aggregator_prompt(user_prompt, llms):
     )
 
 
+def get_ui_friendly_name(llm):
+    return LLM_TO_UI_NAME_MAP.get(llm, llm)
+
+
+def get_parse_judging_response_for_direct_assessment_prompt(
+    judging_responses: dict[str, str],
+    criteria_list,
+    options,
+):
+    formatted_judging_responses = "\n\n".join(
+        [
+            f"{get_ui_friendly_name(model)} START\n{judging_responses[model]}\n\n{get_ui_friendly_name(model)} END\n\n\n"
+            for model in judging_responses.keys()
+        ]
+    )
+    return PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT.format(
+        judging_responses=formatted_judging_responses,
+        criteria_list=format_criteria_list(criteria_list),
+        options=format_likert_comparison_options(options),
+    )
+
+
+def get_model_from_direct_assessment_judging_key(judging_key):
+    return judging_key.split("__")[1]
+
+
+def get_direct_assessment_judging_responses():
+    # Get the judging responses from the session state.
+    judging_responses = {}
+    for key in st.session_state.keys():
+        if key.startswith("direct_assessment_judge__"):
+            judging_responses[get_model_from_direct_assessment_judging_key(key)] = (
+                st.session_state[key]
+            )
+    return judging_responses
+
+
+def parse_judging_responses(prompt: str) -> DirectAssessmentJudgingResponse:
+    completion = client.beta.chat.completions.parse(
+        model="gpt-4o-mini",
+        messages=[
+            {
+                "role": "system",
+                "content": "Parse the judging responses into structured data.",
+            },
+            {"role": "user", "content": prompt},
+        ],
+        response_format=DirectAssessmentJudgingResponse,
+    )
+    return completion.choices[0].message.parsed
+
+
+def plot_criteria_scores(df):
+    # Group by criteria and calculate mean and std over all judges.
+    grouped = df.groupby(["criteria"]).agg({"score": ["mean", "std"]}).reset_index()
+
+    # Flatten the MultiIndex columns
+    grouped.columns = ["criteria", "mean_score", "std_score"]
+
+    # Fill NaN std with zeros (in case there's only one score per group)
+    grouped["std_score"] = grouped["std_score"].fillna(0)
+
+    # Set up the plot
+    plt.figure(figsize=(8, 5))
+
+    # Create a horizontal bar plot
+    ax = sns.barplot(
+        data=grouped,
+        x="mean_score",
+        y="criteria",
+        hue="criteria",
+        errorbar=None,  # Updated parameter
+        orient="h",
+    )
+
+    # Add error bars manually
+    # Iterate over the bars and add error bars
+    for i, (mean, std) in enumerate(zip(grouped["mean_score"], grouped["std_score"])):
+        # Get the current bar
+        bar = ax.patches[i]
+        # Calculate the center of the bar
+        center = bar.get_y() + bar.get_height() / 2
+        # Add the error bar
+        ax.errorbar(x=mean, y=center, xerr=std, ecolor="black", capsize=3, fmt="none")
+
+    # Set labels and title
+    ax.set_xlabel("")
+    ax.set_ylabel("")
+    plt.tight_layout()
+
+    # Display the plot in Streamlit
+    st.pyplot(plt.gcf())
+
+
 # Main Streamlit App
 def main():
     st.set_page_config(
@@ -291,7 +419,6 @@ def main():
     selected_models = llm_council_selector()
     st.write("Selected Models:", selected_models)
     selected_aggregator = aggregator_selector()
-    # st.write("Selected Aggregator:", selected_aggregator)
 
     # Prompt input
     user_prompt = st.text_area("Enter your prompt:")
@@ -299,19 +426,26 @@ def main():
     if st.button("Submit"):
         st.write("Responses:")
 
+        response_columns = st.columns(3)
+
+        selected_models_to_streamlit_column_map = {
+            model: response_columns[i] for i, model in enumerate(selected_models)
+        }
+
         # Fetching and streaming responses from each selected model
-        # TODO: Make this asynchronous?
-        for model in selected_models:
-            with st.chat_message(
-                model,
-                avatar=PROVIDER_TO_AVATAR_MAP[model],
-            ):
-                message_placeholder = st.empty()
-                stream = get_llm_response_stream(model, user_prompt)
-                if stream:
-                    st.session_state[get_response_key(model)] = (
-                        message_placeholder.write_stream(stream)
-                    )
+        for selected_model in selected_models:
+            with selected_models_to_streamlit_column_map[selected_model]:
+                st.write(get_ui_friendly_name(selected_model))
+                with st.chat_message(
+                    selected_model,
+                    avatar=PROVIDER_TO_AVATAR_MAP[selected_model],
+                ):
+                    message_placeholder = st.empty()
+                    stream = get_llm_response_stream(selected_model, user_prompt)
+                    if stream:
+                        st.session_state[get_response_key(selected_model)] = (
+                            message_placeholder.write_stream(stream)
+                        )
 
         # Get the aggregator prompt.
         aggregator_prompt = get_default_aggregator_prompt(
@@ -319,10 +453,12 @@ def main():
         )
 
         with st.expander("Aggregator Prompt"):
-            st.write(aggregator_prompt)
+            st.code(aggregator_prompt)
 
         # Fetching and streaming response from the aggregator
-        st.write(f"Mixture-of-Agents response from {selected_aggregator}:")
+        st.write(
+            f"Mixture-of-Agents response from {get_ui_friendly_name(selected_aggregator)}"
+        )
         with st.chat_message(
             selected_aggregator,
             avatar=PROVIDER_TO_AVATAR_MAP[selected_aggregator],
@@ -348,11 +484,12 @@ def main():
 
         # Depending on the assessment type, render different forms
         if assessment_type == "Direct Assessment":
-            direct_assessment_prompt = st.text_area(
-                "Prompt for the Direct Assessment",
-                value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
-                height=500,
-            )
+            with st.expander("Direct Assessment Prompt"):
+                direct_assessment_prompt = st.text_area(
+                    "Prompt for the Direct Assessment",
+                    value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
+                    height=500,
+                )
 
             # TODO: Add option to edit criteria list with a basic text field.
             criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
@@ -365,7 +502,7 @@ def main():
 
             response_judging_columns = st.columns(3)
 
-            responses_for_judging_to_streamlit_column_index_map = {
+            responses_for_judging_to_streamlit_column_map = {
                 model: response_judging_columns[i % 3]
                 for i, model in enumerate(responses_for_judging.keys())
            }
@@ -373,37 +510,42 @@ def main():
             # Get judging responses.
             for response_model, response in responses_for_judging.items():
 
-                st_column = response_judging_columns[
-                    responses_for_judging_to_streamlit_column_index_map[
-                        response_model
-                    ]
+                st_column = responses_for_judging_to_streamlit_column_map[
+                    response_model
                 ]
 
                 with st_column:
-
-                    st.write(f"Judging {response_model}")
+                    if "aggregator_response" in response_model:
+                        judging_model_header = "Mixture-of-Agents Response"
+                    else:
+                        judging_model_header = get_ui_friendly_name(response_model)
+                    st.write(f"Judging for {judging_model_header}")
                     judging_prompt = get_direct_assessment_prompt(
-                        direct_assessment_prompt,
-                        user_prompt,
-                        response,
-                        criteria_list,
-                        SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
+                        direct_assessment_prompt=direct_assessment_prompt,
+                        user_prompt=user_prompt,
+                        response=response,
+                        criteria_list=criteria_list,
+                        options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                     )
 
+                    with st.expander("Final Judging Prompt"):
+                        st.code(judging_prompt)
+
                    for judging_model in selected_models:
-                        with st.expander("Detailed assessments", expanded=True):
+                        with st.expander(
+                            get_ui_friendly_name(judging_model), expanded=False
+                        ):
                             with st.chat_message(
                                 judging_model,
                                 avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
                             ):
-                                st.write(f"Judge: {judging_model}")
                                 message_placeholder = st.empty()
                                 judging_stream = get_llm_response_stream(
                                     judging_model, judging_prompt
                                 )
                                 if judging_stream:
                                     st.session_state[
-                                        get_judging_key(
+                                        get_direct_assessment_judging_key(
                                             judging_model, response_model
                                         )
                                     ] = message_placeholder.write_stream(
@@ -412,6 +554,30 @@ def main():
                     # When all of the judging is finished for the given response, get the actual
                     # values, parsed (use gpt-4o-mini for now) with json mode.
                     # TODO.
+                    judging_responses = get_direct_assessment_judging_responses()
+                    parse_judging_response_prompt = (
+                        get_parse_judging_response_for_direct_assessment_prompt(
+                            judging_responses,
+                            criteria_list,
+                            SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
+                        )
+                    )
+                    # Issue the prompt to openai mini with structured outputs
+                    parsed_judging_responses = parse_judging_responses(
+                        parse_judging_response_prompt
+                    )
+
+                    df = create_dataframe_for_direct_assessment_judging_response(
+                        parsed_judging_responses
+                    )
+                    st.write(df)
+
+                    # Log the output using st.write() under an st.expander
+                    # with st.expander("Parsed Judging Responses", expanded=True):
+                    #     st.write(parsed_judging_responses)
+                    plot_criteria_scores(df)
+
+                    # TODO: Use parsed_judging_responses for further processing or display
 
         elif assessment_type == "Pairwise Comparison":
             pairwise_comparison_prompt = st.text_area(
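Note (not part of the commit): a minimal sketch of the renamed session-state key helpers, assuming the move from "." to "__" separators is meant to keep the keys reversible for model identifiers that themselves contain dots (e.g. the Together Llama 3.1 identifiers in constants.py).

# Sketch of the "__" key scheme; names mirror the helpers in app.py.
def get_response_key(model):
    return model + "__response"

def get_model_from_response_key(response_key):
    return response_key.split("__")[0]

model = "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
key = get_response_key(model)
# Round-trips cleanly; the old split(".")[0] would have truncated at "Meta-Llama-3".
assert get_model_from_response_key(key) == model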
 
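Note (not part of the commit): a self-contained sketch of the new parse-then-plot path, using hand-built objects in place of the gpt-4o-mini structured-output call; the criterion names and scores below are placeholders, not values from the app.

import pandas as pd
from judging_dataclasses import (
    DirectAssessmentCriteriaScores,
    DirectAssessmentCriterionScore,
    DirectAssessmentJudgingResponse,
)

# Stand-in for parse_judging_responses(); criterion names are illustrative only.
parsed = DirectAssessmentJudgingResponse(
    judging_models=[
        DirectAssessmentCriteriaScores(
            model="openai://gpt-4o-mini",
            criteria_scores=[
                DirectAssessmentCriterionScore(
                    criterion="helpfulness", score=6, explanation="Addresses the prompt."
                ),
                DirectAssessmentCriterionScore(
                    criterion="conciseness", score=5, explanation="Slightly verbose."
                ),
            ],
        ),
    ]
)

# Same flattening that create_dataframe_for_direct_assessment_judging_response() performs.
df = pd.DataFrame(
    [
        {
            "llm_judge_model": jm.model,
            "criteria": cs.criterion,
            "score": cs.score,
            "explanation": cs.explanation,
        }
        for jm in parsed.judging_models
        for cs in jm.criteria_scores
    ]
)

# plot_criteria_scores() aggregates mean/std per criterion before drawing the bar plot.
print(df.groupby("criteria")["score"].agg(["mean", "std"]))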
constants.py CHANGED
@@ -24,6 +24,16 @@ PROVIDER_TO_AVATAR_MAP = {
     "anthropic://claude-3-haiku-20240307": "data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSIxZW0iIGhlaWdodD0iMWVtIiB2aWV3Qm94PSIwIDAgMjQgMjQiPjxwYXRoIGZpbGw9ImN1cnJlbnRDb2xvciIgZD0iTTE3LjMwNCAzLjU0MWgtMy42NzJsNi42OTYgMTYuOTE4SDI0Wm0tMTAuNjA4IDBMMCAyMC40NTloMy43NDRsMS4zNy0zLjU1M2g3LjAwNWwxLjM2OSAzLjU1M2gzLjc0NEwxMC41MzYgMy41NDFabS0uMzcxIDEwLjIyM0w4LjYxNiA3LjgybDIuMjkxIDUuOTQ1WiIvPjwvc3ZnPg==",
 }
 
+LLM_TO_UI_NAME_MAP = {
+    "openai://gpt-4o-mini": "GPT-4 Turbo Mini",
+    "anthropic://claude-3-5-sonnet": "Claude 3 Sonnet",
+    "vertex://gemini-1.5-flash-001": "Gemini 1.5 Flash",
+    "together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": "Llama 3.1 8B Instruct",
+    "together://meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": "Llama 3.1 70B Instruct",
+    "together://meta-llama/Llama-3.2-3B-Instruct-Turbo": "Llama 3.2 3B Instruct",
+    "anthropic://claude-3-haiku-20240307": "Claude 3 Haiku",
+}
+
 # AGGREGATORS = ["openai://gpt-4o-mini", "openai://gpt-4o"]
 AGGREGATORS = ["together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"]
 
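Note (not part of the commit): LLM_TO_UI_NAME_MAP is consumed by app.py's get_ui_friendly_name(), which falls back to the raw identifier when a model is missing from the map; a quick sketch of that behavior (the unknown identifier below is made up).

LLM_TO_UI_NAME_MAP = {
    "anthropic://claude-3-haiku-20240307": "Claude 3 Haiku",
}

def get_ui_friendly_name(llm):
    # Mirrors the helper added in app.py: unknown models display their raw identifier.
    return LLM_TO_UI_NAME_MAP.get(llm, llm)

print(get_ui_friendly_name("anthropic://claude-3-haiku-20240307"))  # "Claude 3 Haiku"
print(get_ui_friendly_name("openai://some-new-model"))              # falls back to the identifier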
judging_dataclasses.py CHANGED
@@ -26,3 +26,18 @@ class PairwiseComparison(BaseModel):
 
 class JudgingConfig(BaseModel):
     assessment: Union[DirectAssessment, PairwiseComparison]
+
+
+class DirectAssessmentCriterionScore(BaseModel):
+    criterion: str
+    score: int
+    explanation: str
+
+
+class DirectAssessmentCriteriaScores(BaseModel):
+    model: str
+    criteria_scores: List[DirectAssessmentCriterionScore]
+
+
+class DirectAssessmentJudgingResponse(BaseModel):
+    judging_models: List[DirectAssessmentCriteriaScores]
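Note (not part of the commit): these Pydantic models also serve as the response_format schema for the gpt-4o-mini parsing call in app.py. A small offline sketch validating a plain dict; the field values are illustrative, and model_validate assumes Pydantic v2 (v1 would use parse_obj).

from judging_dataclasses import DirectAssessmentJudgingResponse

raw = {
    "judging_models": [
        {
            "model": "anthropic://claude-3-haiku-20240307",
            "criteria_scores": [
                {"criterion": "accuracy", "score": 7, "explanation": "No factual errors."},
            ],
        },
    ],
}

parsed = DirectAssessmentJudgingResponse.model_validate(raw)  # Pydantic v2
print(parsed.judging_models[0].criteria_scores[0].score)  # 7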
prompts.py CHANGED
@@ -1,6 +1,21 @@
 from judging_dataclasses import Criteria
 
 
+PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT = """We are trying to parse the responses from the judges for a direct assessment.
+
+Each judge was asked to give a rating for each of the following criteria, along with an explanation:
+{criteria_list}
+
+The possible options for each criterion are as follows:
+{options}
+
+The responses from the judges are as follows:
+{judging_responses}
+
+Please provide a JSON object with the following structure that includes the model name and the scores for each of the criteria, along with the explanation.
+"""
+
+
 DEFAULT_AGGREGATOR_PROMPT = """We are trying to come up with the best response to a user query based on an aggregation of other responses.
 
 [USER PROMPT START]
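Note (not part of the commit): the new template is filled via str.format() in app.py's get_parse_judging_response_for_direct_assessment_prompt(). The sketch below substitutes pre-rendered strings for format_criteria_list() and format_likert_comparison_options(), which are not shown in this diff.

from prompts import PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT

prompt = PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT.format(
    criteria_list="- accuracy\n- conciseness",       # placeholder criteria rendering
    options="1 (lowest) through 7 (highest)",        # placeholder 7-point scale rendering
    judging_responses="Claude 3 Haiku START\n...\nClaude 3 Haiku END",
)
print(prompt.splitlines()[0])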