import streamlit as st
import pandas as pd
from together import Together
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
import os
from config import DATASETS, MODELS
load_dotenv()
client = Together(api_key=os.getenv('TOGETHERAI_API_KEY'))
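# Dataset loading: pull the configured split from the Hugging Face Hub and
# normalize each row into a question/options/answer dict (cached per session).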
@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    # Keep only single-answer multiple-choice questions
    df = df[df['choice_type'] == 'single']
    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        # 'cop' holds the zero-based index of the correct option
        correct_answer = options[row['cop']]
        question_dict = {
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        }
        questions.append(question_dict)
    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions
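# Model querying: build the prompt, call the Together chat completions API once,
# and extract the "answer" field from the JSON object the model returns.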
def get_model_response(question, options, prompt_template, model_name):
    """Query the selected model once and parse its JSON answer.

    Returns a (parsed_answer, raw_response_text) tuple; parsed_answer is an
    "Error: ..." string when the response cannot be used.
    """
    try:
        model_config = MODELS[model_name]
        options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
        prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
        response = client.chat.completions.create(
            model=model_config["model_id"],
            messages=[{"role": "user", "content": prompt}]
        )
        response_text = response.choices[0].message.content.strip()
        # Pull the first JSON object out of the response text
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not json_match:
            return "Error: No JSON object found in the model response", response_text
        json_response = json.loads(json_match.group(0))
        answer = json_response['answer'].strip()
        # Strip a leading option letter such as "A. " if the model included one
        answer = re.sub(r'^[A-D]\.\s*', '', answer)
        if not any(answer.lower() == opt.lower() for opt in options):
            return f"Error: Answer '{answer}' does not match any options", response_text
        return answer, response_text
    except Exception as e:
        return f"Error: {str(e)}", ""
def evaluate_response(model_response, correct_answer):
    if model_response.startswith("Error:"):
        return False
    return model_response.lower().strip() == correct_answer.lower().strip()
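# Streamlit UI: dataset/model pickers, an editable prompt template, and a
# per-question results view with overall accuracy.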
def main():
    st.set_page_config(page_title="Medical LLM Evaluation", layout="wide")
    st.title("Medical LLM Evaluation")
    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        selected_model = st.selectbox(
            "Select Model",
            options=list(MODELS.keys()),
            help="Choose the model to evaluate"
        )
    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.

Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format containing an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "answer": the option you selected
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options

Example response format:
{
    "answer": "exact option text here",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''
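    # Prompt editor plus a quick reference for the template variables it supports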
    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="This prompt is editable; feel free to adjust it before running the evaluation."
        )
    with col2:
        st.markdown("""
### Prompt Variables
- `{question}`: The medical question
- `{options}`: The multiple choice options
""")
    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)
    if not questions:
        st.error("No questions were loaded successfully.")
        return
    subjects = list(set(q['subject_name'] for q in questions))
    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
    if selected_subject != "All":
        questions = [q for q in questions if q['subject_name'] == selected_subject]
    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))
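    # Run the evaluation: one model call per question, with live progress updates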
    if st.button("Start Evaluation"):
        if not os.getenv('TOGETHERAI_API_KEY'):
            st.error("Please set the TOGETHERAI_API_KEY in your .env file")
            return
        progress_bar = st.progress(0)
        status_text = st.empty()
        results_container = st.container()
        results = []
        for i in range(num_questions):
            question = questions[i]
            progress = (i + 1) / num_questions
            progress_bar.progress(progress)
            status_text.text(f"Evaluating question {i + 1}/{num_questions}")
            # Single API call per question; the raw response is reused for display
            model_response, raw_response = get_model_response(
                question['question'],
                question['options'],
                prompt_template,
                selected_model
            )
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(question['options'])])
            formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text)
            is_correct = evaluate_response(model_response, question['correct_answer'])
            results.append({
                'question': question['question'],
                'options': question['options'],
                'model_response': model_response,
                'raw_llm_response': raw_response,
                'prompt_sent': formatted_prompt,
                'correct_answer': question['correct_answer'],
                'subject': question['subject_name'],
                'is_correct': is_correct,
                'explanation': question['explanation']
            })
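        # Summarize and display the results once all questions are scored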
        with results_container:
            st.subheader("Evaluation Results")
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            st.metric("Accuracy", f"{accuracy:.2%}")
            for idx, result in enumerate(results):
                st.markdown("---")
                st.subheader(f"Question {idx + 1} - {result['subject']}")
                st.write("Question:", result['question'])
                st.write("Options:")
                for i, opt in enumerate(result['options']):
                    st.write(f"{chr(65+i)}. {opt}")
                col1, col2 = st.columns(2)
                with col1:
                    with st.expander("Show Prompt"):
                        st.code(result['prompt_sent'])
                with col2:
                    with st.expander("Show Raw Response"):
                        st.code(result['raw_llm_response'])
                col1, col2 = st.columns(2)
                with col1:
                    st.write("Correct Answer:", result['correct_answer'])
                    st.write("Model Answer:", result['model_response'])
                with col2:
                    if result['is_correct']:
                        st.success("Correct!")
                    else:
                        st.error("Incorrect")
                with st.expander("Show Explanation"):
                    st.write(result['explanation'])
if __name__ == "__main__":
    main()