"""Gradio chat demo for the ``wybxc/new-yiri`` ERNIE causal language model."""

from typing import Optional, cast

import gradio as gr
import torch
from transformers import BertTokenizerFast, ErnieForCausalLM


def load_model() -> tuple[BertTokenizerFast, ErnieForCausalLM]:
    """Load the tokenizer and model for the demo from ``wybxc/new-yiri``.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    tokenizer = BertTokenizerFast.from_pretrained("wybxc/new-yiri")
    # The asserts narrow the loosely-typed ``from_pretrained`` results for
    # static type checkers.
    assert isinstance(tokenizer, BertTokenizerFast)
    model = ErnieForCausalLM.from_pretrained("wybxc/new-yiri")
    assert isinstance(model, ErnieForCausalLM)
    return tokenizer, model


def generate(
    tokenizer: BertTokenizerFast,
    model: ErnieForCausalLM,
    input_str: str,
    alpha: float,
    topk: int,
    max_new_tokens: int = 100,
) -> str:
    """Generate a reply to ``input_str`` with contrastive search.

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode the reply.
        model: Causal LM that produces the continuation.
        input_str: The user's message, used as the prompt.
        alpha: Contrastive-search degeneration penalty (``penalty_alpha``).
        topk: Number of candidate tokens considered per step (``top_k``).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply (prompt excluded, special tokens removed, all
        spaces stripped).
    """
    input_ids = cast(torch.Tensor, tokenizer.encode(input_str, return_tensors="pt"))
    prompt_len = input_ids.shape[-1]

    outputs = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        # Slider values may arrive as floats; top_k must be an integer.
        penalty_alpha=float(alpha),
        top_k=int(topk),
        early_stopping=True,
        decoder_start_token_id=tokenizer.sep_token_id,
        eos_token_id=tokenizer.sep_token_id,
    )

    # For a causal LM, generate() returns the prompt followed by the new
    # tokens. Cutting at the prompt length (instead of at the first [SEP])
    # stays correct even if the user typed a literal "[SEP]".
    reply_ids = outputs[0, prompt_len:]
    reply = tokenizer.decode(reply_ids, skip_special_tokens=True)
    # decode() joins word pieces with spaces; they are removed here
    # (presumably because the model targets Chinese text — confirm).
    return reply.replace(" ", "")


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot().style(height=500)
            with gr.Row():
                with gr.Column(scale=4):
                    msg = gr.Textbox(
                        show_label=False, placeholder="Enter text and press enter"
                    ).style(container=False)
                    msg = cast(gr.Textbox, msg)
                with gr.Column(scale=1):
                    button = gr.Button("Generate")
                with gr.Column(scale=1):
                    clear = gr.Button("Clear")
        with gr.Column(scale=1):
            alpha = gr.Slider(0, 1, 0.5, step=0.01, label="Penalty Alpha")
            topk = gr.Slider(1, 50, 5, step=1, label="Top K")

    tokenizer, model = load_model()

    def user(user_message: str, history: list[list[Optional[str]]]):
        """Clear the textbox and append the message with an empty bot slot."""
        return "", [*history, [user_message, None]]

    def bot(history: list[list[Optional[str]]], alpha: float, topk: int):
        """Fill the bot slot of the last history entry with a generated reply."""
        user_message = cast(str, history[-1][0])
        history[-1][1] = generate(
            tokenizer,
            model,
            user_message,
            alpha=alpha,
            topk=topk,
        )
        return history

    # Pressing Enter and clicking "Generate" trigger the same two-step flow:
    # record the user message, then produce the bot reply.
    for trigger in (msg.submit, button.click):
        trigger(user, inputs=[msg, chatbot], outputs=[msg, chatbot]).then(
            bot, inputs=[chatbot, alpha, topk], outputs=[chatbot]
        )
    clear.click(lambda: None, None, chatbot)

if __name__ == "__main__":
    demo.queue(concurrency_count=3)
    demo.launch()