"""Gradio chat demo for the "wybxc/new-yiri" ERNIE causal language model."""

from typing import cast

import gradio as gr
import torch
from transformers import BertTokenizerFast, ErnieForCausalLM
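
# ErnieForCausalLM (ERNIE with a causal language-modeling head) is available
# in recent releases of transformers; the checkpoint pairs it with a plain
# BERT tokenizer.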


def load_model():
    # Load the tokenizer and model weights from the "wybxc/new-yiri" repo on
    # the Hugging Face Hub (or a local directory of the same name). The asserts
    # narrow the broad return types of from_pretrained for the type checker.
    tokenizer = BertTokenizerFast.from_pretrained("wybxc/new-yiri")
    assert isinstance(tokenizer, BertTokenizerFast)
    model = ErnieForCausalLM.from_pretrained("wybxc/new-yiri")
    assert isinstance(model, ErnieForCausalLM)

    return tokenizer, model


def generate(
    tokenizer: BertTokenizerFast,
    model: ErnieForCausalLM,
    input_str: str,
    alpha: float,
    topk: int,
):
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    input_ids = cast(torch.Tensor, input_ids)
    # penalty_alpha together with top_k selects contrastive-search decoding in
    # transformers' generate(); [SEP] is reused as the end-of-sequence token.
    outputs = model.generate(
        input_ids,
        max_new_tokens=100,
        penalty_alpha=alpha,
        top_k=topk,
        early_stopping=True,
        decoder_start_token_id=tokenizer.sep_token_id,
        eos_token_id=tokenizer.sep_token_id,
    )
    # The prompt is echoed at the start of the output, so decode only from the
    # first [SEP] (the end of the prompt) onward, then drop the spaces the
    # tokenizer inserts between decoded tokens.
    i, *_ = torch.nonzero(outputs[0] == tokenizer.sep_token_id)
    output = tokenizer.decode(
        outputs[0, i:],
        skip_special_tokens=True,
    ).replace(" ", "")
    return output
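
# A quick way to exercise generate() outside the UI, e.g. from a REPL
# (the prompt string and parameter values here are only illustrative):
#
#     tokenizer, model = load_model()
#     print(generate(tokenizer, model, "你好", alpha=0.5, topk=5))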


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: chat history plus the message box and buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot().style(height=500)
            with gr.Row():
                with gr.Column(scale=4):
                    msg = gr.Textbox(
                        show_label=False, placeholder="Enter text and press enter"
                    ).style(container=False)
                    msg = cast(gr.Textbox, msg)
                with gr.Column(scale=1):
                    button = gr.Button("Generate")
                with gr.Column(scale=1):
                    clear = gr.Button("Clear")
        # Right column: decoding parameters passed through to generate().
        with gr.Column(scale=1):
            alpha = gr.Slider(0, 1, 0.5, step=0.01, label="Penalty Alpha")
            topk = gr.Slider(1, 50, 5, step=1, label="Top K")

    tokenizer, model = load_model()

    def on_message(
        user_message: str, history: list[list[str]], alpha: float, topk: int
    ):
        bot_message = generate(
            tokenizer,
            model,
            user_message,
            alpha=alpha,
            topk=topk,
        )
        # Clear the textbox and append the new (user, bot) pair to the history.
        return "", [*history, [user_message, bot_message]]

    msg.submit(on_message, inputs=[msg, chatbot, alpha, topk], outputs=[msg, chatbot])
    button.click(on_message, inputs=[msg, chatbot, alpha, topk], outputs=[msg, chatbot])

    clear.click(lambda: None, None, chatbot)


if __name__ == "__main__":
    # queue() limits concurrent generation requests (Gradio 3.x API) before
    # launch() starts the local web server.
    demo.queue(concurrency_count=3)
    demo.launch()