|
import gradio as gr |
|
from datasets import load_dataset |
|
import torch |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
|
|
|
|
# Run on CUDA when PyTorch reports it is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")
|
|
|
|
|
# Load every split of the NuminaMath-CoT dataset from the Hugging Face Hub;
# get_random_example samples rows from its "train" split.
ds = load_dataset(path="AI-MO/NuminaMath-CoT")
|
|
|
|
|
# Hub checkpoint shared by the tokenizer and the generation model.
model_name = "google/flan-t5-base"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)  # place the weights on the device selected at startup
|
|
|
def process_example(example):
    """Format a single dataset row as a Question / Solution / Answer text block.

    Args:
        example: Mapping for one dataset row. The problem text is read from
            ``question`` if present, otherwise from ``problem``. ``solution``
            is required. ``answer`` is optional.

    Returns:
        A string with a "Question: ..." line and a "Solution: ..." line,
        plus an "Answer: ..." line when the row has a non-None ``answer``.

    Raises:
        KeyError: if the row has neither ``question`` nor ``problem``, or has
            no ``solution``.
    """
    # NuminaMath-CoT rows are keyed ``problem``/``solution`` and have no
    # ``answer`` column, so hard-coding ``example['question']`` and
    # ``example['answer']`` raised KeyError for every row.
    if 'question' in example:
        question = example['question']
    elif 'problem' in example:
        question = example['problem']
    else:
        raise KeyError("example has neither a 'question' nor a 'problem' field")

    solution = example['solution']
    answer = example.get('answer')

    lines = [f"Question: {question}", f"Solution: {solution}"]
    if answer is not None:
        lines.append(f"Answer: {answer}")
    return "\n".join(lines)
|
|
|
def get_random_example():
    """Pick one row at random from ds['train'] and return it formatted as text.

    The row is formatted by process_example. Raises ValueError (from
    ``random.randint``) if the train split is empty.
    """
    import random

    train_split = ds['train']
    last_index = len(train_split) - 1
    chosen_index = random.randint(0, last_index)
    return process_example(train_split[chosen_index])
|
|
|
def solve_math_problem(question):
    """Generate a solution for a math problem with the FLAN-T5 model.

    Args:
        question: Problem text from the input textbox.

    Returns:
        The decoded generation with special tokens stripped. If ``question``
        is None, empty or whitespace-only, the model is not called and a
        short prompt asking for a problem is returned instead.
    """
    # Blank input would otherwise send only the bare "solve math: " prefix
    # to the model and display whatever it hallucinates.
    if question is None or not question.strip():
        return "Please enter a math problem."

    # Named `prompt` so it does not shadow the module-level `input_text` Textbox.
    prompt = "solve math: " + question
    # Prompts longer than 512 tokens are truncated rather than rejected.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)

    # Unpacking the encoding forwards attention_mask along with input_ids,
    # instead of leaving generate() to infer the mask.
    outputs = model.generate(
        **inputs,
        max_length=200,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,  # sampled decoding: repeated calls may return different text
        top_p=0.9,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# Math Problem Solver")
    gr.Markdown("Using FLAN-T5 model to solve mathematical problems with step-by-step solutions.")

    with gr.Row():
        # Left column: problem entry and the two action buttons.
        with gr.Column():
            input_text = gr.Textbox(
                lines=3,
                label="Enter your math problem",
                placeholder="Type your math problem here...",
            )
            with gr.Row():
                submit_btn = gr.Button("Solve Problem", variant="primary")
                example_btn = gr.Button("Show Random Example")

        # Right column: generated solution.
        with gr.Column():
            output_text = gr.Textbox(
                show_copy_button=True,
                lines=8,
                label="Solution",
            )

    # "Solve Problem" sends the textbox contents to the model.
    submit_btn.click(solve_math_problem, inputs=input_text, outputs=output_text)
    # "Show Random Example" writes a formatted dataset row into the input box.
    # NOTE(review): that text includes the reference solution, so clicking
    # "Solve Problem" afterwards also feeds the solution to the model.
    example_btn.click(get_random_example, inputs=None, outputs=input_text)

demo.launch(share=True)