# app.py — Tiny-Hinglish chat-prediction demo (Gradio Space).
# Model: Abhishekcr448/Tiny-Hinglish-Chat-21M
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the causal-LM checkpoint and its matching tokenizer once at startup.
_CHECKPOINT = "Abhishekcr448/Tiny-Hinglish-Chat-21M"
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
# Function to generate text (suggestions)
def generate_text(prompt, output_length, temperature, top_k, top_p):
    """Generate a short continuation of ``prompt`` with the loaded model.

    Args:
        prompt: Seed text to continue.
        output_length: Number of new tokens to generate.
        temperature: Sampling temperature (higher = more random).
        top_k: Top-k vocabulary truncation used during sampling.
        top_p: Nucleus-sampling cumulative-probability cutoff.

    Returns:
        The prompt followed by the generated continuation, as one string
        (special tokens stripped).
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    generated_output = model.generate(
        inputs['input_ids'],
        # Pass the attention mask explicitly; without it transformers warns
        # and generation behavior is not well-defined.
        attention_mask=inputs['attention_mask'],
        # max_new_tokens states the intent directly, replacing the stale
        # "input length + output_length" max_length arithmetic (whose old
        # comment still claimed a hard-coded 10 tokens).
        max_new_tokens=output_length,
        no_repeat_ngram_size=2,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    output_text = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    return output_text
# Set up the Gradio interface with custom CSS
with gr.Blocks(css="""
#response-text {
background-color: #e1bee7; /* Light purple background */
border-radius: 8px; /* Rounded corners */
padding: 10px; /* Padding inside the textbox */
font-size: 16px; /* Font size */
color: #4a148c; /* Dark purple text color */
}
""") as demo:
    # Title shown at the top of the app.
    gr.Markdown("# Hinglish Chat Prediction")

    # Chat history, seeded with an opening assistant message.
    with gr.Row():
        chatbox = gr.Chatbot(
            label="Chat",
            type="messages",
            height=350,
            value=[{"role": "assistant", "content": "Kya kar rahe ho"}],
        )

    # Row 1: user input plus the Submit button.
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(label="Start chatting", interactive=True)
        with gr.Column(scale=1):
            submit_button = gr.Button("Submit", variant="primary", elem_id="submit-btn")

    # Row 2: read-only suggestion box plus its action buttons.
    with gr.Row():
        with gr.Column(scale=3):
            response_text = gr.Textbox(label="Suggestion", interactive=False, elem_id="response-text")
        with gr.Column(scale=1):
            replace_button = gr.Button("Use Suggestion", variant="secondary", elem_id="replace-btn")
            regenerate_button = gr.Button("Regenerate", variant="secondary", elem_id="regenerate-btn")

    # Collapsible panel for the sampling parameters.
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Change Parameters", open=False):
                output_length_slider = gr.Slider(1, 20, value=8, label="Output Length", step=1)
                temperature_slider = gr.Slider(0.1, 1.0, value=0.8, label="Temperature (Controls randomness)")
                top_k_slider = gr.Slider(1, 100, value=50, label="Top-k (Limits vocabulary size)", step=1)
                top_p_slider = gr.Slider(0.1, 1.0, value=0.9, label="Top-p (Nucleus sampling)")

    def validate_and_generate(prompt, output_length, temperature, top_k, top_p):
        """Generate a suggestion for a non-blank prompt; return "" otherwise.

        Previously a blank prompt fell through and returned None, which
        both handed Gradio a None textbox value and crashed
        chat_interaction's slicing of the result.
        """
        if prompt.strip():
            return generate_text(prompt, output_length, temperature, top_k, top_p)
        return ""

    # Live suggestion while the user types, and one-click adoption of it.
    input_text.input(
        validate_and_generate,
        inputs=[input_text, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=response_text,
    )
    replace_button.click(lambda x: x, inputs=response_text, outputs=input_text)

    def chat_interaction(prompt, history, output_length, temperature, top_k, top_p):
        """Append the user's message and the model's reply to the chat.

        Returns (history, next-suggestion, "") — the trailing empty string
        clears the input box after a send. A blank prompt is a no-op.
        """
        if not prompt.strip():
            return history, "", ""
        response = generate_text(prompt, output_length, temperature, top_k, top_p)
        # The model echoes the prompt; keep only the continuation.
        response = response[len(prompt):].strip()
        history.append({"role": "user", "content": prompt})
        history.append({"role": "assistant", "content": response})
        # Pre-compute a suggestion that continues the assistant's reply
        # (safe even when the reply is empty, since validate_and_generate
        # now returns "" for blank prompts).
        suggestion = validate_and_generate(response, output_length, temperature, top_k, top_p)
        return history, suggestion[len(response):].strip(), ""

    def regenerate_text(input_text, history, output_length, temperature, top_k, top_p):
        """Produce a fresh suggestion.

        Prefers the current input text; otherwise continues the most recent
        chat message. With neither available, returns "" instead of raising
        IndexError on an empty history.
        """
        if input_text.strip():
            return generate_text(input_text, output_length, temperature, top_k, top_p)
        if not history:
            # Nothing to regenerate from.
            return ""
        last_message = history[-1]["content"]
        return generate_text(last_message, output_length, temperature, top_k, top_p)[len(last_message):].strip()

    # Enter key and Submit button trigger the same chat round-trip.
    input_text.submit(
        chat_interaction,
        inputs=[input_text, chatbox, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=[chatbox, response_text, input_text],
    )
    submit_button.click(
        chat_interaction,
        inputs=[input_text, chatbox, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=[chatbox, response_text, input_text],
    )
    regenerate_button.click(
        regenerate_text,
        inputs=[input_text, chatbox, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=response_text,
    )

# Launch the interface
demo.launch()