|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
# Pretrained Tiny-Hinglish chat model on the Hugging Face Hub.
# Downloaded on first run, served from the local HF cache afterwards.
MODEL_NAME = "Abhishekcr448/Tiny-Hinglish-Chat-21M"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
|
|
def generate_text(prompt, output_length, temperature, top_k, top_p):
    """Sample a Hinglish continuation of ``prompt`` from the language model.

    Args:
        prompt: Seed text to continue.
        output_length: Number of *new* tokens to generate beyond the prompt.
        temperature: Sampling temperature (higher = more random).
        top_k: Top-k vocabulary truncation for sampling.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded text, *including* the original prompt (callers slice the
        prompt off themselves where they need only the continuation).
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    generated_output = model.generate(
        inputs['input_ids'],
        # Pass the attention mask and a pad token explicitly: omitting them
        # makes transformers emit warnings and guess the mask from the pad
        # token, which can corrupt generation.
        attention_mask=inputs['attention_mask'],
        pad_token_id=tokenizer.eos_token_id,
        # Equivalent to the old max_length = prompt_len + output_length,
        # but states the intent ("N new tokens") directly.
        max_new_tokens=output_length,
        no_repeat_ngram_size=2,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
    )
    output_text = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    return output_text
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: a chat window plus a live "suggestion" box that previews what the
# model would say next. Styling targets the suggestion textbox via elem_id.
# ---------------------------------------------------------------------------
with gr.Blocks(css="""
#response-text {
    background-color: #e1bee7; /* Light purple background */
    border-radius: 8px; /* Rounded corners */
    padding: 10px; /* Padding inside the textbox */
    font-size: 16px; /* Font size */
    color: #4a148c; /* Dark purple text color */
}
""") as demo:

    gr.Markdown("# Hinglish Chat Prediction")

    with gr.Row():
        chatbox = gr.Chatbot(
            label="Chat",
            type="messages",
            height=350,
            value=[{"role": "assistant", "content": "Kya kar rahe ho"}],
        )

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(label="Start chatting", interactive=True)
        with gr.Column(scale=1):
            submit_button = gr.Button("Submit", variant="primary", elem_id="submit-btn")

    with gr.Row():
        with gr.Column(scale=3):
            response_text = gr.Textbox(label="Suggestion", interactive=False, elem_id="response-text")
        with gr.Column(scale=1):
            replace_button = gr.Button("Use Suggestion", variant="secondary", elem_id="replace-btn")
            regenerate_button = gr.Button("Regenerate", variant="secondary", elem_id="regenerate-btn")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Change Parameters", open=False):
                output_length_slider = gr.Slider(1, 20, value=8, label="Output Length", step=1)
                temperature_slider = gr.Slider(0.1, 1.0, value=0.8, label="Temperature (Controls randomness)")
                top_k_slider = gr.Slider(1, 100, value=50, label="Top-k (Limits vocabulary size)", step=1)
                top_p_slider = gr.Slider(0.1, 1.0, value=0.9, label="Top-p (Nucleus sampling)")

    def validate_and_generate(prompt, output_length, temperature, top_k, top_p):
        """Generate a continuation for a non-blank prompt.

        Returns the full generated text (prompt included), or "" for a
        blank/whitespace prompt. BUG FIX: the original fell through and
        returned None for blank prompts, which made chat_interaction crash
        on `None[len(response):]` whenever the model's reply stripped to
        an empty string.
        """
        if prompt.strip():
            print(f"Prompt: {prompt}")
            return generate_text(prompt, output_length, temperature, top_k, top_p)
        return ""

    input_text.input(
        validate_and_generate,
        inputs=[input_text, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=response_text,
    )
    # Copy the current suggestion into the input box.
    replace_button.click(lambda x: x, inputs=response_text, outputs=input_text)

    def chat_interaction(prompt, history, output_length, temperature, top_k, top_p):
        """Append the user turn and the model's reply, then pre-compute the
        next suggestion.

        Returns (updated history, next suggestion, "" to clear the input box).
        Blank prompts are a no-op.
        """
        if not prompt.strip():
            return history, "", ""

        response = generate_text(prompt, output_length, temperature, top_k, top_p)
        # generate_text echoes the prompt; keep only the continuation.
        response = response[len(prompt):].strip()
        history.append({"role": "user", "content": prompt})
        history.append({"role": "assistant", "content": response})

        # Preview what the model would say after its own reply. With the ""
        # fix above this is safe even when `response` is empty.
        suggestion_full = validate_and_generate(response, output_length, temperature, top_k, top_p)
        return history, suggestion_full[len(response):].strip(), ""

    def regenerate_text(input_text, history, output_length, temperature, top_k, top_p):
        """Re-sample the suggestion: from the typed text if any, otherwise
        from the last chat message (continuation only in that case)."""
        if input_text.strip():
            return generate_text(input_text, output_length, temperature, top_k, top_p)
        last_message = history[-1]["content"]
        return generate_text(last_message, output_length, temperature, top_k, top_p)[len(last_message):].strip()

    input_text.submit(
        chat_interaction,
        inputs=[input_text, chatbox, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=[chatbox, response_text, input_text],
    )
    submit_button.click(
        chat_interaction,
        inputs=[input_text, chatbox, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=[chatbox, response_text, input_text],
    )
    regenerate_button.click(
        regenerate_text,
        inputs=[input_text, chatbox, output_length_slider, temperature_slider, top_k_slider, top_p_slider],
        outputs=response_text,
    )

demo.launch()
|
|