"""Streamlit app for evaluating LLMs on medical multiple-choice questions via the Together API."""

import json
import os
import re

import pandas as pd
import streamlit as st
from datasets import load_dataset
from dotenv import load_dotenv
from together import Together

from config import DATASETS, MODELS
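
# NOTE (assumption): config.py is expected to provide two dicts whose entries carry the
# only keys read in this file:
#   DATASETS = {"<display name>": {"loader": "<Hugging Face dataset id, e.g. openlifescienceai/medmcqa>"}}
#   MODELS = {"<display name>": {"model_id": "<Together model id>"}}
# The column names used below (opa/opb/opc/opd, cop, exp, ...) follow the MedMCQA schema.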

# Reads TOGETHERAI_API_KEY from a local .env file.
load_dotenv()
client = Together(api_key=os.getenv('TOGETHERAI_API_KEY'))

@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    """Load a dataset configured in DATASETS and return its single-select questions as dicts."""
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])

    # Keep only single-answer questions.
    df = df[df['choice_type'] == 'single']

    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        correct_answer = options[row['cop']]  # 'cop' holds the index of the correct option

        question_dict = {
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        }
        questions.append(question_dict)

    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions

def get_model_response(question, options, prompt_template, model_name):
    """Query the selected model and return its answer text, or an 'Error: ...' string."""
    try:
        model_config = MODELS[model_name]
        options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
        prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)

        response = client.chat.completions.create(
            model=model_config["model_id"],
            messages=[{"role": "user", "content": prompt}]
        )

        response_text = response.choices[0].message.content.strip()

        # Extract the JSON object from the completion and pull out the "answer" field.
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not json_match:
            return "Error: No JSON object found in the model response"
        json_response = json.loads(json_match.group(0))
        answer = json_response['answer'].strip()
        # Strip a leading option label such as "A. " if the model included one.
        answer = re.sub(r'^[A-D]\.\s*', '', answer)

        if not any(answer.lower() == opt.lower() for opt in options):
            return f"Error: Answer '{answer}' does not match any options"

        return answer

    except Exception as e:
        return f"Error: {str(e)}"

def evaluate_response(model_response, correct_answer):
    """Return True if the model's answer matches the correct option (case-insensitive)."""
    if model_response.startswith("Error:"):
        return False
    return model_response.lower().strip() == correct_answer.lower().strip()

def main():
    st.set_page_config(page_title="Medical LLM Evaluation", layout="wide")
    st.title("Medical LLM Evaluation")

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        selected_model = st.selectbox(
            "Select Model",
            options=list(MODELS.keys()),
            help="Choose the model to evaluate"
        )

    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.

Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "answer": the option you selected
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options

Example response format:
{
    "answer": "exact option text here",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="The prompt below is editable; feel free to adjust it before starting a run."
        )

    with col2:
        st.markdown("""
### Prompt Variables
- `{question}`: The medical question
- `{options}`: The multiple choice options
""")

    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)

    if not questions:
        st.error("No questions were loaded successfully.")
        return

    subjects = list(set(q['subject_name'] for q in questions))
    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)

    if selected_subject != "All":
        questions = [q for q in questions if q['subject_name'] == selected_subject]

    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))

    if st.button("Start Evaluation"):
        if not os.getenv('TOGETHERAI_API_KEY'):
            st.error("Please set the TOGETHERAI_API_KEY in your .env file")
            return

        progress_bar = st.progress(0)
        status_text = st.empty()
        results_container = st.container()

        results = []
        for i in range(num_questions):
            question = questions[i]
            progress = (i + 1) / num_questions
            progress_bar.progress(progress)
            status_text.text(f"Evaluating question {i + 1}/{num_questions}")

            model_response = get_model_response(
                question['question'],
                question['options'],
                prompt_template,
                selected_model
            )

            # The same prompt is sent a second time to capture a raw completion for display,
            # so each question costs two API calls and the raw text shown may differ from
            # the response that was parsed above.
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])])
            formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text)
            raw_response = client.chat.completions.create(
                model=MODELS[selected_model]["model_id"],
                messages=[{"role": "user", "content": formatted_prompt}]
            ).choices[0].message.content.strip()

            is_correct = evaluate_response(model_response, question['correct_answer'])

            results.append({
                'question': question['question'],
                'options': question['options'],
                'model_response': model_response,
                'raw_llm_response': raw_response,
                'prompt_sent': formatted_prompt,
                'correct_answer': question['correct_answer'],
                'subject': question['subject_name'],
                'is_correct': is_correct,
                'explanation': question['explanation']
            })

        with results_container:
            st.subheader("Evaluation Results")
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            st.metric("Accuracy", f"{accuracy:.2%}")

            for idx, result in enumerate(results):
                st.markdown("---")
                st.subheader(f"Question {idx + 1} - {result['subject']}")

                st.write("Question:", result['question'])
                st.write("Options:")
                for i, opt in enumerate(result['options']):
                    st.write(f"{chr(65+i)}. {opt}")

                col1, col2 = st.columns(2)
                with col1:
                    with st.expander("Show Prompt"):
                        st.code(result['prompt_sent'])
                with col2:
                    with st.expander("Show Raw Response"):
                        st.code(result['raw_llm_response'])

                col1, col2 = st.columns(2)
                with col1:
                    st.write("Correct Answer:", result['correct_answer'])
                    st.write("Model Answer:", result['model_response'])
                with col2:
                    if result['is_correct']:
                        st.success("Correct!")
                    else:
                        st.error("Incorrect")

                with st.expander("Show Explanation"):
                    st.write(result['explanation'])


if __name__ == "__main__":
    main()
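
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py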