File size: 5,429 Bytes
2e320f4
 
 
 
 
 
 
 
 
 
c475ad8
2e320f4
 
 
 
 
c475ad8
 
 
2e320f4
 
 
 
 
c475ad8
 
 
 
 
 
 
 
 
 
2e320f4
c475ad8
 
2e320f4
c475ad8
 
 
2e320f4
 
 
 
 
c475ad8
2e320f4
c475ad8
 
 
 
 
 
 
 
2e320f4
c475ad8
 
 
 
9e4509a
c475ad8
2e320f4
c475ad8
 
2e320f4
c475ad8
 
 
 
 
2e320f4
 
c475ad8
2e320f4
c475ad8
 
2e320f4
c475ad8
2e320f4
 
c475ad8
2e320f4
c475ad8
2e320f4
 
 
 
 
 
 
c475ad8
2e320f4
c475ad8
 
 
 
 
 
 
 
 
2e320f4
c475ad8
 
 
2e320f4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the causal language model, then its matching tokenizer, from the
# same Hugging Face Hub checkpoint.
model = AutoModelForCausalLM.from_pretrained("Abhishekcr448/Tiny-Hinglish-Chat-21M")
tokenizer = AutoTokenizer.from_pretrained("Abhishekcr448/Tiny-Hinglish-Chat-21M")

# Function to generate text (suggestions)
def generate_text(prompt, output_length, temperature, top_k, top_p):
    """Sample a continuation of ``prompt`` with the module-level model.

    Args:
        prompt: Text to continue.
        output_length: Number of new tokens to generate after the prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        top_p: Nucleus (top-p) sampling cutoff forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded to text with special tokens
        removed.
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    generated_output = model.generate(
        input_ids=inputs['input_ids'],
        # Forward the tokenizer's attention mask instead of letting generate()
        # guess one from the input ids.
        attention_mask=inputs.get('attention_mask'),
        # Ask for exactly `output_length` additional tokens; int() guards
        # against UI widgets delivering whole numbers as floats.
        max_new_tokens=int(output_length),
        no_repeat_ngram_size=2,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        do_sample=True,
    )
    return tokenizer.decode(generated_output[0], skip_special_tokens=True)

# Set up the Gradio interface with custom CSS
with gr.Blocks(css="""
    #response-text {
        background-color: #e1bee7; /* Light purple background */
        border-radius: 8px; /* Rounded corners */
        padding: 10px; /* Padding inside the textbox */
        font-size: 16px; /* Font size */
        color: #4a148c; /* Dark purple text color */
    }
""") as demo:

    # Page heading
    gr.Markdown("# Hinglish Chat Prediction")

    # Conversation view, seeded with one opening assistant message
    with gr.Row():
        chatbox = gr.Chatbot(
            label="Chat",
            type="messages",
            height=350,
            value=[{"role": "assistant", "content": "Kya kar rahe ho"}],
        )

    # Message entry row: wide column for the textbox, narrow column for Submit
    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(label="Start chatting", interactive=True)

        with gr.Column(scale=1):
            submit_button = gr.Button(
                "Submit",
                variant="primary",
                elem_id="submit-btn",
            )

    # Suggestion row: read-only textbox (targeted by the #response-text CSS
    # rule) next to the buttons that act on the suggestion
    with gr.Row():
        with gr.Column(scale=3):
            response_text = gr.Textbox(
                label="Suggestion",
                interactive=False,
                elem_id="response-text",
            )

        with gr.Column(scale=1):
            replace_button = gr.Button(
                "Use Suggestion",
                variant="secondary",
                elem_id="replace-btn",
            )
            regenerate_button = gr.Button(
                "Regenerate",
                variant="secondary",
                elem_id="regenerate-btn",
            )

    # Collapsed accordion holding the text-generation parameter sliders
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("Change Parameters", open=False):
                output_length_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=8,
                    step=1,
                    label="Output Length",
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.8,
                    label="Temperature (Controls randomness)",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k (Limits vocabulary size)",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    label="Top-p (Nucleus sampling)",
                )
    
    # Set up the interaction between input and output
    def validate_and_generate(prompt, output_length, temperature, top_k, top_p):
        """Generate a suggestion for ``prompt``, skipping blank input.

        Args:
            prompt: Text to continue; may be empty, whitespace-only or None.
            output_length: Forwarded to ``generate_text``.
            temperature: Forwarded to ``generate_text``.
            top_k: Forwarded to ``generate_text``.
            top_p: Forwarded to ``generate_text``.

        Returns:
            The text returned by ``generate_text``, or an empty string when
            the prompt is blank.
        """
        # Always return a string: callers slice the result, which would fail
        # on an implicit None.
        if not prompt or not prompt.strip():
            return ""
        print(f"Prompt: {prompt}")
        return generate_text(prompt, output_length, temperature, top_k, top_p)

    # Produce a fresh suggestion whenever the user edits the input box
    input_text.input(
        validate_and_generate,
        inputs=[
            input_text,
            output_length_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
        ],
        outputs=response_text,
    )
    # "Use Suggestion" copies the suggestion text into the input box unchanged
    replace_button.click(
        lambda suggestion: suggestion,
        inputs=response_text,
        outputs=input_text,
    )

    def chat_interaction(prompt, history, output_length, temperature, top_k, top_p):
        """Handle a submitted message: reply to it and suggest a follow-up.

        Args:
            prompt: The user's message; blank input leaves the chat unchanged.
            history: Chat messages as a list of ``{"role", "content"}`` dicts
                (may be None or empty).
            output_length: Forwarded to the generation helpers.
            temperature: Forwarded to the generation helpers.
            top_k: Forwarded to the generation helpers.
            top_p: Forwarded to the generation helpers.

        Returns:
            A tuple ``(history, suggestion, "")``: the updated message list,
            the suggested continuation of the assistant reply, and an empty
            string used to clear the input box.
        """
        # Work on a copy so the caller's list is not mutated; this also
        # tolerates a missing history.
        history = list(history or [])
        if not prompt or not prompt.strip():
            return history, "", ""

        generated = generate_text(prompt, output_length, temperature, top_k, top_p)

        # The generated text is expected to begin with the prompt; drop that
        # many leading characters to keep only the model's reply.
        response = generated[len(prompt):].strip()
        history.append({"role": "user", "content": prompt})
        history.append({"role": "assistant", "content": response})

        # Continue from the reply to pre-compute the next suggestion.
        # validate_and_generate yields a falsy value for a blank reply, so
        # fall back to "" before slicing to avoid a TypeError.
        suggestion = validate_and_generate(response, output_length, temperature, top_k, top_p) or ""
        return history, suggestion[len(response):].strip(), ""
        
    def regenerate_text(input_text, history, output_length, temperature, top_k, top_p):
        """Produce a new suggestion for the current input or the last message.

        Args:
            input_text: Current contents of the input box.
            history: Chat messages as a list of ``{"role", "content"}`` dicts
                (may be None or empty).
            output_length: Forwarded to ``generate_text``.
            temperature: Forwarded to ``generate_text``.
            top_k: Forwarded to ``generate_text``.
            top_p: Forwarded to ``generate_text``.

        Returns:
            When the input box has text, the full text generated from it.
            Otherwise the continuation of the last chat message (with that
            message's leading characters removed), or an empty string when
            there is no usable last message.
        """
        if input_text and input_text.strip():
            return generate_text(input_text, output_length, temperature, top_k, top_p)

        # Nothing typed: fall back to the most recent chat message, if any.
        if not history:
            return ""
        last_message = history[-1]["content"]
        # A blank (or non-text) message gives the model nothing to continue.
        if not isinstance(last_message, str) or not last_message.strip():
            return ""
        generated = generate_text(last_message, output_length, temperature, top_k, top_p)
        return generated[len(last_message):].strip()

    # Pressing Enter in the input box and clicking Submit run the same chat
    # handler with the same inputs and outputs.
    for chat_trigger in (input_text.submit, submit_button.click):
        chat_trigger(
            chat_interaction,
            inputs=[
                input_text,
                chatbox,
                output_length_slider,
                temperature_slider,
                top_k_slider,
                top_p_slider,
            ],
            outputs=[chatbox, response_text, input_text],
        )

    # "Regenerate" only refreshes the suggestion box
    regenerate_button.click(
        regenerate_text,
        inputs=[
            input_text,
            chatbox,
            output_length_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
        ],
        outputs=response_text,
    )

# Launch the interface: runs at import/execution time of this script, after
# the Blocks layout and its event handlers have been defined.
demo.launch()