justinxzhao committed
Commit 38e43b5 · 1 Parent(s): 3e0f8f8

Overall scores graph complete.

Files changed (4)
  1. .gitignore +2 -1
  2. app.py +361 -123
  3. img/council_icon.png +0 -0
  4. prompts.py +4 -1
.gitignore CHANGED
@@ -1,2 +1,3 @@
  env/
- client_secret.json
+ client_secret.json
+ __pycache__
app.py CHANGED
@@ -15,10 +15,15 @@ from constants import (
15
  LLM_TO_UI_NAME_MAP,
16
  )
17
  from prompts import *
18
- from judging_dataclasses import DirectAssessmentJudgingResponse
19
  import pandas as pd
20
  import seaborn as sns
21
  import matplotlib.pyplot as plt
 
22
 
23
  dotenv.load_dotenv()
24
 
@@ -67,6 +72,16 @@ def anthropic_streamlit_streamer(stream):
67
  break # End of message, stop streaming
68
 
69
 
70
  def google_streamlit_streamer(stream):
71
  for chunk in stream:
72
  yield chunk.text
@@ -146,22 +161,6 @@ def get_llm_response_stream(model_identifier, prompt):
146
  return None
147
 
148
 
149
- def get_response_key(model):
150
- return model + "__response"
151
-
152
-
153
- def get_model_from_response_key(response_key):
154
- return response_key.split("__")[0]
155
-
156
-
157
- def get_direct_assessment_judging_key(judge_model, response_model):
158
- return "direct_assessment_judge__" + judge_model + "__" + response_model
159
-
160
-
161
- def get_aggregator_response_key(model):
162
- return model + "__aggregator_response"
163
-
164
-
165
  def create_dataframe_for_direct_assessment_judging_response(
166
  response: DirectAssessmentJudgingResponse,
167
  ):
@@ -203,21 +202,6 @@ def render_criteria_form(criteria_num):
203
  )
204
 
205
 
206
- def get_response_mapping():
207
- # Inspect the session state for all the responses.
208
- # This is a dictionary mapping model names to their responses.
209
- # The aggregator response is also included in this mapping under the key "<model>__aggregator_response".
210
- response_mapping = {}
211
- for key in st.session_state.keys():
212
- if "judge" in key:
213
- continue
214
- if key.endswith("__response"):
215
- response_mapping[get_model_from_response_key(key)] = st.session_state[key]
216
- if key.endswith("__aggregator_response"):
217
- response_mapping[key] = st.session_state[key]
218
- return response_mapping
219
-
220
-
221
  def format_likert_comparison_options(options):
222
  return "\n".join([f"{i + 1}: {option}" for i, option in enumerate(options)])
223
 
@@ -252,7 +236,7 @@ def get_default_direct_assessment_prompt(user_prompt):
252
  def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
253
  responses_from_other_llms = "\n\n".join(
254
  [
255
- f"{get_ui_friendly_name(model)} START\n{st.session_state.get(get_response_key(model))}\n\n{get_ui_friendly_name(model)} END\n\n\n"
256
  for model in llms
257
  ]
258
  )
@@ -270,10 +254,6 @@ def get_default_aggregator_prompt(user_prompt, llms):
270
  )
271
 
272
 
273
- def get_ui_friendly_name(llm):
274
- return LLM_TO_UI_NAME_MAP.get(llm, llm)
275
-
276
-
277
  def get_parse_judging_response_for_direct_assessment_prompt(
278
  judging_responses: dict[str, str],
279
  criteria_list,
@@ -292,34 +272,58 @@ def get_parse_judging_response_for_direct_assessment_prompt(
292
  )
293
 
294
 
295
- def get_model_from_direct_assessment_judging_key(judging_key):
296
- return judging_key.split("__")[1]
297
-
298
-
299
- def get_direct_assessment_judging_responses():
300
- # Get the judging responses from the session state.
301
- judging_responses = {}
302
- for key in st.session_state.keys():
303
- if key.startswith("direct_assessment_judge__"):
304
- judging_responses[get_model_from_direct_assessment_judging_key(key)] = (
305
- st.session_state[key]
306
- )
307
- return judging_responses
308
-
309
-
310
- def parse_judging_responses(prompt: str) -> DirectAssessmentJudgingResponse:
311
- completion = client.beta.chat.completions.parse(
312
- model="gpt-4o-mini",
313
- messages=[
314
- {
315
- "role": "system",
316
- "content": "Parse the judging responses into structured data.",
317
- },
318
- {"role": "user", "content": prompt},
319
- ],
320
- response_format=DirectAssessmentJudgingResponse,
321
- )
322
- return completion.choices[0].message.parsed
323
 
324
 
325
  def plot_criteria_scores(df):
@@ -364,6 +368,94 @@ def plot_criteria_scores(df):
364
  st.pyplot(plt.gcf())
365
 
366
 
367
  # Main Streamlit App
368
  def main():
369
  st.set_page_config(
@@ -395,7 +487,7 @@ def main():
395
 
396
  # App title and description
397
  st.title("Language Model Council Sandbox")
398
- st.markdown("###### Invoke a council of LLMs to generate and judge each other.")
399
  st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")
400
 
401
  # Authentication system
@@ -413,18 +505,30 @@ def main():
413
  st.error("Invalid credentials")
414
 
415
  if st.session_state.authenticated:
416
- st.success("Logged in successfully!")
 
417
 
418
  # Council and aggregator selection
419
  selected_models = llm_council_selector()
420
- st.write("Selected Models:", selected_models)
421
  selected_aggregator = aggregator_selector()
422
 
423
  # Prompt input
424
- user_prompt = st.text_area("Enter your prompt:")
425
 
426
- if st.button("Submit"):
427
- st.write("Responses:")
428
 
429
  response_columns = st.columns(3)
430
 
@@ -443,7 +547,7 @@ def main():
443
  message_placeholder = st.empty()
444
  stream = get_llm_response_stream(selected_model, user_prompt)
445
  if stream:
446
- st.session_state[get_response_key(selected_model)] = (
447
  message_placeholder.write_stream(stream)
448
  )
449
 
@@ -456,25 +560,25 @@ def main():
456
  st.code(aggregator_prompt)
457
 
458
  # Fetching and streaming response from the aggregator
459
- st.write(
460
- f"Mixture-of-Agents response from {get_ui_friendly_name(selected_aggregator)}"
461
- )
462
  with st.chat_message(
463
  selected_aggregator,
464
- avatar=PROVIDER_TO_AVATAR_MAP[selected_aggregator],
465
  ):
466
  message_placeholder = st.empty()
467
  aggregator_stream = get_llm_response_stream(
468
  selected_aggregator, aggregator_prompt
469
  )
470
  if aggregator_stream:
471
- message_placeholder.write_stream(aggregator_stream)
472
- st.session_state[
473
- get_aggregator_response_key(selected_aggregator)
474
- ] = message_placeholder.write_stream(aggregator_stream)
 
 
475
 
476
  # Judging.
477
- st.markdown("#### Judging Configuration Form")
478
 
479
  # Choose the type of assessment
480
  assessment_type = st.radio(
@@ -482,9 +586,48 @@ def main():
482
  options=["Direct Assessment", "Pairwise Comparison"],
483
  )
484
 
 
 
485
  # Depending on the assessment type, render different forms
486
  if assessment_type == "Direct Assessment":
487
- with st.expander("Direct Assessment Prompt"):
488
  direct_assessment_prompt = st.text_area(
489
  "Prompt for the Direct Assessment",
490
  value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
@@ -495,10 +638,15 @@ def main():
495
  criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
496
 
497
  # Create DirectAssessment object when form is submitted
498
- if st.button("Submit Direct Assessment"):
 
 
499
 
500
  # Submit direct assessment.
501
- responses_for_judging = get_response_mapping()
502
 
503
  response_judging_columns = st.columns(3)
504
 
@@ -515,11 +663,13 @@ def main():
515
  ]
516
 
517
  with st_column:
518
- if "aggregator_response" in response_model:
519
  judging_model_header = "Mixture-of-Agents Response"
520
  else:
521
  judging_model_header = get_ui_friendly_name(response_model)
522
  st.write(f"Judging for {judging_model_header}")
 
 
523
  judging_prompt = get_direct_assessment_prompt(
524
  direct_assessment_prompt=direct_assessment_prompt,
525
  user_prompt=user_prompt,
@@ -543,18 +693,27 @@ def main():
543
  judging_stream = get_llm_response_stream(
544
  judging_model, judging_prompt
545
  )
546
- if judging_stream:
547
- st.session_state[
548
- get_direct_assessment_judging_key(
549
- judging_model, response_model
550
- )
551
- ] = message_placeholder.write_stream(
552
- judging_stream
553
- )
554
  # When all of the judging is finished for the given response, get the actual
555
  # values, parsed (use gpt-4o-mini for now) with json mode.
556
  # TODO.
557
- judging_responses = get_direct_assessment_judging_responses()
558
  parse_judging_response_prompt = (
559
  get_parse_judging_response_for_direct_assessment_prompt(
560
  judging_responses,
@@ -562,45 +721,124 @@ def main():
562
  SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
563
  )
564
  )
 
 
565
  # Issue the prompt to openai mini with structured outputs
566
  parsed_judging_responses = parse_judging_responses(
567
- parse_judging_response_prompt
568
  )
569
 
570
- df = create_dataframe_for_direct_assessment_judging_response(
 
 
571
  parsed_judging_responses
572
  )
573
- st.write(df)
574
 
575
- # Log the output using st.write() under an st.expander
576
- # with st.expander("Parsed Judging Responses", expanded=True):
577
- # st.write(parsed_judging_responses)
578
- plot_criteria_scores(df)
 
579
 
580
- # TODO: Use parsed_judging_responses for further processing or display
581
 
582
- elif assessment_type == "Pairwise Comparison":
583
- pairwise_comparison_prompt = st.text_area(
584
- "Prompt for the Pairwise Comparison"
585
- )
586
- granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
587
- ties_allowed = st.checkbox("Are ties allowed?")
588
- position_swapping = st.checkbox("Enable position swapping?")
589
- reference_model = st.text_input("Reference Model")
590
-
591
- # Create PairwiseComparison object when form is submitted
592
- if st.button("Submit Pairwise Comparison"):
593
- pairwise_comparison_config = PairwiseComparison(
594
- type="pairwise_comparison",
595
- granularity=granularity,
596
- ties_allowed=ties_allowed,
597
- position_swapping=position_swapping,
598
- reference_model=reference_model,
599
- prompt=prompt,
600
  )
601
- st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")
602
- # Submit pairwise comparison.
603
- responses_for_judging = get_response_mapping()
604
 
605
  else:
606
  with cols[1]:
 
15
  LLM_TO_UI_NAME_MAP,
16
  )
17
  from prompts import *
18
+ from judging_dataclasses import (
19
+ DirectAssessmentJudgingResponse,
20
+ DirectAssessmentCriterionScore,
21
+ DirectAssessmentCriteriaScores,
22
+ )
23
  import pandas as pd
24
  import seaborn as sns
25
  import matplotlib.pyplot as plt
26
+ import numpy as np
27
 
28
  dotenv.load_dotenv()
29
 
 
72
  break # End of message, stop streaming
73
 
74
 
75
+ def get_ui_friendly_name(llm):
76
+ if "agg__" in llm:
77
+ return (
78
+ "MoA ("
79
+ + LLM_TO_UI_NAME_MAP.get(llm.split("__")[1], llm.split("__")[1])
80
+ + ")"
81
+ )
82
+ return LLM_TO_UI_NAME_MAP.get(llm, llm)
83
+
84
+
85
  def google_streamlit_streamer(stream):
86
  for chunk in stream:
87
  yield chunk.text
 
161
  return None
162
 
163
 
164
  def create_dataframe_for_direct_assessment_judging_response(
165
  response: DirectAssessmentJudgingResponse,
166
  ):
 
202
  )
203
 
204
 
205
  def format_likert_comparison_options(options):
206
  return "\n".join([f"{i + 1}: {option}" for i, option in enumerate(options)])
207
 
 
236
  def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
237
  responses_from_other_llms = "\n\n".join(
238
  [
239
+ f"{get_ui_friendly_name(model)} START\n{st.session_state['responses'][model]}\n\n{get_ui_friendly_name(model)} END\n\n\n"
240
  for model in llms
241
  ]
242
  )
 
254
  )
255
 
256
 
257
  def get_parse_judging_response_for_direct_assessment_prompt(
258
  judging_responses: dict[str, str],
259
  criteria_list,
 
272
  )
273
 
274
 
275
+ DEBUG_MODE = True
276
+
277
+
278
+ def parse_judging_responses(
279
+ prompt: str, judging_responses: dict[str, str]
280
+ ) -> DirectAssessmentJudgingResponse:
281
+ if DEBUG_MODE:
282
+ return DirectAssessmentJudgingResponse(
283
+ judging_models=[
284
+ DirectAssessmentCriteriaScores(
285
+ model="together://meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
286
+ criteria_scores=[
287
+ DirectAssessmentCriterionScore(
288
+ criterion="helpfulness", score=3, explanation="explanation1"
289
+ ),
290
+ DirectAssessmentCriterionScore(
291
+ criterion="conciseness", score=4, explanation="explanation2"
292
+ ),
293
+ DirectAssessmentCriterionScore(
294
+ criterion="relevance", score=5, explanation="explanation3"
295
+ ),
296
+ ],
297
+ ),
298
+ DirectAssessmentCriteriaScores(
299
+ model="together://meta-llama/Llama-3.2-3B-Instruct-Turbo",
300
+ criteria_scores=[
301
+ DirectAssessmentCriterionScore(
302
+ criterion="helpfulness", score=1, explanation="explanation1"
303
+ ),
304
+ DirectAssessmentCriterionScore(
305
+ criterion="conciseness", score=2, explanation="explanation2"
306
+ ),
307
+ DirectAssessmentCriterionScore(
308
+ criterion="relevance", score=3, explanation="explanation3"
309
+ ),
310
+ ],
311
+ ),
312
+ ]
313
+ )
314
+ else:
315
+ completion = client.beta.chat.completions.parse(
316
+ model="gpt-4o-mini",
317
+ messages=[
318
+ {
319
+ "role": "system",
320
+ "content": "Parse the judging responses into structured data.",
321
+ },
322
+ {"role": "user", "content": prompt},
323
+ ],
324
+ response_format=DirectAssessmentJudgingResponse,
325
+ )
326
+ return completion.choices[0].message.parsed
327
 
328
 
329
  def plot_criteria_scores(df):
 
368
  st.pyplot(plt.gcf())
369
 
370
 
371
+ def plot_overall_scores(overall_scores_df):
372
+ # Calculate mean and standard deviation
373
+ summary = (
374
+ overall_scores_df.groupby("response_model")
375
+ .agg({"score": ["mean", "std"]})
376
+ .reset_index()
377
+ )
378
+ summary.columns = ["response_model", "mean_score", "std_score"]
379
+
380
+ # Add UI-friendly names
381
+ summary["ui_friendly_name"] = summary["response_model"].apply(get_ui_friendly_name)
382
+
383
+ # Sort the summary dataframe by mean_score in descending order
384
+ summary = summary.sort_values("mean_score", ascending=False)
385
+
386
+ # Create the plot
387
+ plt.figure(figsize=(8, 5))
388
+
389
+ # Plot bars with rainbow colors
390
+ ax = sns.barplot(
391
+ x="ui_friendly_name",
392
+ y="mean_score",
393
+ data=summary,
394
+ palette="prism",
395
+ capsize=0.1,
396
+ )
397
+
398
+ # Add error bars manually
399
+ x_coords = range(len(summary))
400
+ plt.errorbar(
401
+ x=x_coords,
402
+ y=summary["mean_score"],
403
+ yerr=summary["std_score"],
404
+ fmt="none",
405
+ c="black",
406
+ capsize=5,
407
+ zorder=10, # Ensure error bars are on top
408
+ )
409
+
410
+ # Add text annotations
411
+ for i, row in summary.iterrows():
412
+ ax.text(
413
+ i,
414
+ row["mean_score"],
415
+ f"{row['mean_score']:.2f}",
416
+ ha="center",
417
+ va="bottom",
418
+ fontweight="bold",
419
+ color="black",
420
+ bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
421
+ )
422
+
423
+ # Customize the plot
424
+ plt.xlabel("")
425
+ plt.ylabel("Overall Score")
426
+ plt.xticks(rotation=45, ha="right")
427
+ plt.tight_layout()
428
+
429
+ # Display the plot in Streamlit
430
+ st.pyplot(plt.gcf())
431
+
432
+
433
+ def plot_per_judge_overall_scores(df):
434
+ # Find the overall score by finding the overall score for each judge, and then averaging
435
+ # over all judges.
436
+ grouped = df.groupby(["llm_judge_model"]).agg({"score": ["mean"]}).reset_index()
437
+ grouped.columns = ["llm_judge_model", "overall_score"]
438
+
439
+ # Create the horizontal bar plot
440
+ plt.figure(figsize=(10, 6))
441
+ ax = sns.barplot(
442
+ data=grouped,
443
+ y="llm_judge_model",
444
+ x="overall_score",
445
+ hue="llm_judge_model",
446
+ orient="h",
447
+ )
448
+
449
+ # Customize the plot
450
+ plt.title("Overall Scores by LLM Judge Model")
451
+ plt.xlabel("Overall Score")
452
+ plt.ylabel("LLM Judge Model")
453
+
454
+ # Adjust layout and display the plot
455
+ plt.tight_layout()
456
+ st.pyplot(plt)
457
+
458
+
459
  # Main Streamlit App
460
  def main():
461
  st.set_page_config(
 
487
 
488
  # App title and description
489
  st.title("Language Model Council Sandbox")
490
+ st.markdown("###### Invoke a council of LLMs to judge each other's responses.")
491
  st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")
492
 
493
  # Authentication system
 
505
  st.error("Invalid credentials")
506
 
507
  if st.session_state.authenticated:
508
+ # cols[1].success("Logged in successfully!")
509
+ st.markdown("#### LLM Council Member Selection")
510
 
511
  # Council and aggregator selection
512
  selected_models = llm_council_selector()
513
+
514
+ # st.write("Selected Models:", selected_models)
515
+
516
  selected_aggregator = aggregator_selector()
517
 
518
+ # Initialize session state for collecting responses.
519
+ if "responses" not in st.session_state:
520
+ st.session_state.responses = {}
521
+ # if "aggregator_response" not in st.session_state:
522
+ # st.session_state.aggregator_response = {}
523
+
524
  # Prompt input
525
+ st.markdown("#### Enter your prompt")
526
+ _, center_column, _ = st.columns([3, 5, 3])
527
+ with center_column:
528
+ user_prompt = st.text_area(value="Say 'Hello World'", label="")
529
 
530
+ if center_column.button("Submit", use_container_width=True):
531
+ st.markdown("#### Responses")
532
 
533
  response_columns = st.columns(3)
534
 
 
547
  message_placeholder = st.empty()
548
  stream = get_llm_response_stream(selected_model, user_prompt)
549
  if stream:
550
+ st.session_state["responses"][selected_model] = (
551
  message_placeholder.write_stream(stream)
552
  )
553
 
 
560
  st.code(aggregator_prompt)
561
 
562
  # Fetching and streaming response from the aggregator
563
+ st.write(f"Mixture-of-Agents ({get_ui_friendly_name(selected_aggregator)})")
 
 
564
  with st.chat_message(
565
  selected_aggregator,
566
+ avatar="img/council_icon.png",
567
  ):
568
  message_placeholder = st.empty()
569
  aggregator_stream = get_llm_response_stream(
570
  selected_aggregator, aggregator_prompt
571
  )
572
  if aggregator_stream:
573
+ st.session_state["responses"]["agg__" + selected_aggregator] = (
574
+ message_placeholder.write_stream(aggregator_stream)
575
+ )
576
+
577
+ # st.write("Responses (in session state):")
578
+ # st.write(st.session_state["responses"])
579
 
580
  # Judging.
581
+ st.markdown("#### Judging Configuration")
582
 
583
  # Choose the type of assessment
584
  assessment_type = st.radio(
 
586
  options=["Direct Assessment", "Pairwise Comparison"],
587
  )
588
 
589
+ _, center_column, _ = st.columns([3, 5, 3])
590
+
591
  # Depending on the assessment type, render different forms
592
  if assessment_type == "Direct Assessment":
593
+
594
+ # Initialize session state for direct assessment.
595
+ if "direct_assessment_overall_score" not in st.session_state:
596
+ st.session_state["direct_assessment_overall_score"] = {}
597
+ if "direct_assessment_judging_df" not in st.session_state:
598
+ st.session_state["direct_assessment_judging_df"] = {}
599
+ for response_model in selected_models:
600
+ st.session_state["direct_assessment_judging_df"][
601
+ response_model
602
+ ] = {}
603
+ # aggregator model
604
+ st.session_state["direct_assessment_judging_df"][
605
+ "agg__" + selected_aggregator
606
+ ] = {}
607
+ if "direct_assessment_judging_responses" not in st.session_state:
608
+ st.session_state["direct_assessment_judging_responses"] = {}
609
+ for response_model in selected_models:
610
+ st.session_state["direct_assessment_judging_responses"][
611
+ response_model
612
+ ] = {}
613
+ # aggregator model
614
+ st.session_state["direct_assessment_judging_responses"][
615
+ "agg__" + selected_aggregator
616
+ ] = {}
617
+ if "direct_assessment_overall_scores" not in st.session_state:
618
+ st.session_state["direct_assessment_overall_scores"] = {}
619
+ for response_model in selected_models:
620
+ st.session_state["direct_assessment_overall_scores"][
621
+ response_model
622
+ ] = {}
623
+ st.session_state["direct_assessment_overall_scores"][
624
+ "agg__" + selected_aggregator
625
+ ] = {}
626
+ if "judging_status" not in st.session_state:
627
+ st.session_state["judging_status"] = "incomplete"
628
+
629
+ # Direct assessment prompt.
630
+ with center_column.expander("Direct Assessment Prompt"):
631
  direct_assessment_prompt = st.text_area(
632
  "Prompt for the Direct Assessment",
633
  value=get_default_direct_assessment_prompt(user_prompt=user_prompt),
 
638
  criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
639
 
640
  # Create DirectAssessment object when form is submitted
641
+ if center_column.button(
642
+ "Submit Direct Assessment", use_container_width=True
643
+ ):
644
 
645
  # Submit direct asssessment.
646
+ responses_for_judging = st.session_state["responses"]
647
+
648
+ # st.write("Responses for judging (in session state):")
649
+ # st.write(responses_for_judging)
650
 
651
  response_judging_columns = st.columns(3)
652
 
 
663
  ]
664
 
665
  with st_column:
666
+ if "agg__" in response_model:
667
  judging_model_header = "Mixture-of-Agents Response"
668
  else:
669
  judging_model_header = get_ui_friendly_name(response_model)
670
  st.write(f"Judging for {judging_model_header}")
671
+ # st.write("Response being judged: ")
672
+ # st.write(response)
673
  judging_prompt = get_direct_assessment_prompt(
674
  direct_assessment_prompt=direct_assessment_prompt,
675
  user_prompt=user_prompt,
 
693
  judging_stream = get_llm_response_stream(
694
  judging_model, judging_prompt
695
  )
696
+ # if judging_stream:
697
+ st.session_state[
698
+ "direct_assessment_judging_responses"
699
+ ][response_model][
700
+ judging_model
701
+ ] = message_placeholder.write_stream(
702
+ judging_stream
703
+ )
704
  # When all of the judging is finished for the given response, get the actual
705
  # values, parsed (use gpt-4o-mini for now) with json mode.
706
  # TODO.
707
+ judging_responses = st.session_state[
708
+ "direct_assessment_judging_responses"
709
+ ][response_model]
710
+
711
+ # st.write("Judging responses (in session state):")
712
+ # st.write(judging_responses)
713
+
714
+ if not judging_responses:
715
+ st.error(f"No judging responses for {response_model}")
716
+ quit()
717
  parse_judging_response_prompt = (
718
  get_parse_judging_response_for_direct_assessment_prompt(
719
  judging_responses,
 
721
  SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
722
  )
723
  )
724
+ with st.expander("Parse Judging Response Prompt"):
725
+ st.code(parse_judging_response_prompt)
726
  # Issue the prompt to openai mini with structured outputs
727
  parsed_judging_responses = parse_judging_responses(
728
+ parse_judging_response_prompt, judging_responses
729
  )
730
 
731
+ st.session_state["direct_assessment_judging_df"][
732
+ response_model
733
+ ] = create_dataframe_for_direct_assessment_judging_response(
734
  parsed_judging_responses
735
  )
736
+ st.write(
737
+ st.session_state["direct_assessment_judging_df"][
738
+ response_model
739
+ ]
740
+ )
741
 
742
+ plot_criteria_scores(
743
+ st.session_state["direct_assessment_judging_df"][
744
+ response_model
745
+ ]
746
+ )
747
 
748
+ # Find the overall score by finding the overall score for each judge, and then averaging
749
+ # over all judges.
750
+ plot_per_judge_overall_scores(
751
+ st.session_state["direct_assessment_judging_df"][
752
+ response_model
753
+ ]
754
+ )
755
 
756
+ grouped = (
757
+ st.session_state["direct_assessment_judging_df"][
758
+ response_model
759
+ ]
760
+ .groupby(["llm_judge_model"])
761
+ .agg({"score": ["mean"]})
762
+ .reset_index()
763
+ )
764
+ grouped.columns = ["llm_judge_model", "overall_score"]
765
+
766
+ # st.write(
767
+ # "Extracting overall scores from this grouped dataframe:"
768
+ # )
769
+ # st.write(grouped)
770
+
771
+ # Save the overall scores to the session state.
772
+ for record in grouped.to_dict(orient="records"):
773
+ st.session_state["direct_assessment_overall_scores"][
774
+ response_model
775
+ ][record["llm_judge_model"]] = record["overall_score"]
776
+
777
+ overall_score = grouped["overall_score"].mean()
778
+ controversy = grouped["overall_score"].std()
779
+ st.write(f"Overall Score: {overall_score:.2f}")
780
+ st.write(f"Controversy: {controversy:.2f}")
781
+
782
+ st.session_state["judging_status"] = "complete"
783
+
784
+ # Judging is complete.
785
+ st.write("#### Results")
786
+ # The session state now contains the overall scores for each response from each judge.
787
+ if st.session_state["judging_status"] == "complete":
788
+ overall_scores_df_raw = pd.DataFrame(
789
+ st.session_state["direct_assessment_overall_scores"]
790
+ ).reset_index()
791
+
792
+ overall_scores_df = pd.melt(
793
+ overall_scores_df_raw,
794
+ id_vars=["index"],
795
+ var_name="response_model",
796
+ value_name="score",
797
+ ).rename(columns={"index": "judging_model"})
798
+
799
+ # Print the overall winner.
800
+ overall_winner = overall_scores_df.loc[
801
+ overall_scores_df["score"].idxmax()
802
+ ]
803
+
804
+ st.write(
805
+ f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
806
  )
807
+ # Find how much the standard deviation overlaps with other models.
808
+ # Calculate separability.
809
+ # TODO.
810
+ st.write(f"**Confidence:** {overall_winner['score']:.2f}")
811
+
812
+ left_column, right_column = st.columns([1, 1])
813
+ with left_column:
814
+ plot_overall_scores(overall_scores_df)
815
+
816
+ with right_column:
817
+ st.dataframe(overall_scores_df)
818
+
819
+ elif assessment_type == "Pairwise Comparison":
820
+ pass
821
+ # pairwise_comparison_prompt = st.text_area(
822
+ # "Prompt for the Pairwise Comparison"
823
+ # )
824
+ # granularity = st.selectbox("Granularity", ["coarse", "fine", "super fine"])
825
+ # ties_allowed = st.checkbox("Are ties allowed?")
826
+ # position_swapping = st.checkbox("Enable position swapping?")
827
+ # reference_model = st.text_input("Reference Model")
828
+
829
+ # # Create PairwiseComparison object when form is submitted
830
+ # if st.button("Submit Pairwise Comparison"):
831
+ # pairwise_comparison_config = PairwiseComparison(
832
+ # type="pairwise_comparison",
833
+ # granularity=granularity,
834
+ # ties_allowed=ties_allowed,
835
+ # position_swapping=position_swapping,
836
+ # reference_model=reference_model,
837
+ # prompt=prompt,
838
+ # )
839
+ # st.success(f"Pairwise Comparison Created: {pairwise_comparison_config}")
840
+ # # Submit pairwise comparison.
841
+ # responses_for_judging = st.session_state["responses"]
842
 
843
  else:
844
  with cols[1]:
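For orientation, the aggregation behind the new plot_overall_scores helper can be sketched in a few lines: group per-judge scores by response model, take the mean and standard deviation, and draw a bar chart with error bars. The dummy scores and simplified column names below are assumptions for illustration; the commit's version additionally maps models to UI-friendly names, colors the bars, and annotates each bar with its mean.

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Hypothetical melted scores: one row per (judging_model, response_model) pair.
overall_scores_df = pd.DataFrame(
    {
        "judging_model": ["judge_a", "judge_b", "judge_a", "judge_b"],
        "response_model": ["model_1", "model_1", "model_2", "model_2"],
        "score": [5.0, 6.0, 3.0, 4.0],
    }
)

# Mean and std per response model, sorted best-first (mirrors the commit's summary dataframe).
summary = (
    overall_scores_df.groupby("response_model")["score"]
    .agg(["mean", "std"])
    .reset_index()
    .sort_values("mean", ascending=False)
)

# Bar chart of means with manual std error bars drawn on top.
plt.figure(figsize=(8, 5))
ax = sns.barplot(x="response_model", y="mean", data=summary)
ax.errorbar(
    x=range(len(summary)),
    y=summary["mean"],
    yerr=summary["std"],
    fmt="none",
    c="black",
    capsize=5,
)
ax.set_ylabel("Overall Score")
plt.tight_layout()
plt.show()  # the app renders this with st.pyplot(plt.gcf()) instead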
img/council_icon.png ADDED
prompts.py CHANGED
@@ -25,7 +25,10 @@ DEFAULT_AGGREGATOR_PROMPT = """We are trying to come up with the best response t
  Responses from other LLMs:
  {responses_from_other_llms}

- Please provide a response the combines the best aspects of the responses above."""
+ Consider how you would combine the best aspects of the responses above into a single response.
+
+ Directly provide your response to the user's query as if you were the original LLM. Do not mention that you are synthesizing the responses from other LLMs.
+ """


  DEFAULT_DIRECT_ASSESSMENT_PROMPT = """We are trying to assess the quality of a response to a user query.
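The revised aggregator prompt leaves a single {responses_from_other_llms} slot. A small sketch of how it is presumably filled, with the template abridged to the lines visible in this diff and hypothetical model names:

# Abridged template: only the portion of DEFAULT_AGGREGATOR_PROMPT shown in this diff.
AGGREGATOR_PROMPT_TEMPLATE = """Responses from other LLMs:
{responses_from_other_llms}

Consider how you would combine the best aspects of the responses above into a single response.

Directly provide your response to the user's query as if you were the original LLM. Do not mention that you are synthesizing the responses from other LLMs.
"""

# Hypothetical council responses, keyed by model name.
responses = {
    "model_a": "Hello World!",
    "model_b": "Hello, World.",
}

# app.py wraps each response in START/END markers before formatting the template.
responses_block = "\n\n".join(
    f"{name} START\n{text}\n\n{name} END" for name, text in responses.items()
)
print(AGGREGATOR_PROMPT_TEMPLATE.format(responses_from_other_llms=responses_block))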