# Space status: Sleeping
import torch | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Load model and tokenizer
DEFAULT_MODEL_NAME = "zeyadusf/text2pandas-T5"


def load_model(model_name=DEFAULT_MODEL_NAME):
    """Load a seq2seq checkpoint and its tokenizer.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports
    a GPU, otherwise on the CPU.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the text2pandas T5
            checkpoint used by this app.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return model, tokenizer


# Loaded once at import time; generate_text reads these module-level names.
model, tokenizer = load_model()
# Define the function to generate text
def generate_text(question, context, max_length=512, num_beams=4, early_stopping=True):
    """Generate pandas code answering *question* about the given *context*.

    The prompt is built as ``"<question> {question} <context> {context}"``,
    tokenized, and decoded with ``model.generate`` using the module-level
    ``model`` and ``tokenizer``.

    Args:
        question: Natural-language question to translate into pandas code.
        context: Text describing the DataFrame (e.g. the code that creates it).
        max_length: Token limit used both to truncate the prompt and as the
            maximum generated length. Rounded to the nearest int.
        num_beams: Beam width for ``model.generate``. Rounded to the nearest
            int and clamped to at least 1.
        early_stopping: Passed through to ``model.generate``.

    Returns:
        The decoded top sequence with special tokens removed.
    """
    # Values coming from UI sliders may be floats (e.g. 4.0 or 4.37); the
    # tokenizer's truncation length and generate()'s beam width need ints.
    max_length = int(round(max_length))
    num_beams = max(1, int(round(num_beams)))

    input_text = f"<question> {question} <context> {context}"
    # Calling the tokenizer directly (instead of .encode) also returns the
    # attention mask, so generate() gets it explicitly rather than inferring it.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=early_stopping,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio interface
def gradio_interface(question, context, max_length, num_beams, early_stopping):
    """Adapter called by Gradio with the five UI input values.

    Forwards every value unchanged to ``generate_text`` and returns its
    generated text.
    """
    generated = generate_text(
        question,
        context,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=early_stopping,
    )
    return generated
# Gradio UI Components
# Both sliders use step=1 so they only emit whole numbers. When step is
# omitted, Gradio derives it from the range, which can be fractional
# (e.g. 0.01 for the 1-10 beam range), and num_beams must be an integer.
question_input = gr.Textbox(label="Enter the Question", value="what is the total amount of players for the rockets in 1998 only?")
context_input = gr.Textbox(label="Enter the Context", value="df = pd.DataFrame(columns=['player', 'years_for_rockets'])")
max_length_input = gr.Slider(minimum=50, maximum=1024, value=512, step=1, label="Max Length")
num_beams_input = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Number of Beams")
early_stopping_input = gr.Checkbox(value=True, label="Early Stopping")

# Custom CSS to style the slider, checkbox, and center the button
custom_css = """
/* Make the slider handle and bar light green */
input[type="range"] {
accent-color: lightgreen;
}
input[type="range"]::-webkit-slider-thumb {
background-color: #90EE90; /* Light green slider thumb */
}
input[type="range"]::-webkit-slider-runnable-track {
background-color: #32CD32; /* Light green slider track */
}
/* Make the checkbox light green */
input[type="checkbox"] {
accent-color: lightgreen;
}
/* Center the button */
.gr-button.gr-button-primary {
display: block;
margin: 0 auto;
background-color: #90EE90; /* Light green button */
color: black;
border-radius: 8px;
border: 2px solid #006400; /* Dark green border */
}
"""

# Create the Gradio Interface. The input order must match the parameter
# order of gradio_interface(question, context, max_length, num_beams,
# early_stopping).
demo = gr.Interface(
    fn=gradio_interface,
    inputs=[question_input, context_input, max_length_input, num_beams_input, early_stopping_input],
    outputs="text",
    title="Text to Pandas Code Generator",
    description="Generate Pandas code by providing a question and a context.",
    css=custom_css,  # Apply the custom CSS
)
demo.launch()