NeoYiri / app.py
wybxc's picture
feat: init
3cd2d35 unverified
raw
history blame
2.53 kB
from typing import cast
import gradio as gr
import torch
from transformers import BertTokenizerFast, ErnieForCausalLM
def load_model():
    """Load the NeoYiri tokenizer and causal language model.

    Both are fetched with ``from_pretrained`` from the ``wybxc/new-yiri``
    repository. The tokenizer is loaded (and checked) before the model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "wybxc/new-yiri"

    tok = BertTokenizerFast.from_pretrained(checkpoint)
    # Narrow the from_pretrained result to the concrete class for type checkers.
    assert isinstance(tok, BertTokenizerFast)

    lm = ErnieForCausalLM.from_pretrained(checkpoint)
    assert isinstance(lm, ErnieForCausalLM)

    return tok, lm
def generate(
    tokenizer: "BertTokenizerFast",
    model: "ErnieForCausalLM",
    input_str: str,
    alpha: float,
    topk: int,
):
    """Generate a chat reply to ``input_str`` using contrastive search.

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode the reply.
        model: Causal LM whose ``generate`` produces the continuation.
        input_str: The user's message.
        alpha: Passed to ``model.generate`` as ``penalty_alpha``.
        topk: Passed to ``model.generate`` as ``top_k``; coerced to ``int``
            because UI sliders may deliver it as a float.

    Returns:
        The newly generated text, decoded with special tokens skipped and
        every space character removed.
    """
    input_ids = tokenizer.encode(input_str, return_tensors="pt")
    input_ids = cast(torch.Tensor, input_ids)
    # For a decoder-only model, generate() returns prompt + continuation, so
    # everything at or after this index is newly generated.
    prompt_len = input_ids.shape[-1]
    outputs = model.generate(
        input_ids,
        max_new_tokens=100,
        penalty_alpha=alpha,
        top_k=int(topk),
        early_stopping=True,
        decoder_start_token_id=tokenizer.sep_token_id,
        eos_token_id=tokenizer.sep_token_id,
    )
    # Slice by prompt length instead of searching for the first [SEP]: a
    # literal "[SEP]" typed by the user is encoded as the special token, and
    # the old search then echoed the rest of the user's text into the reply.
    output = tokenizer.decode(
        outputs[0, prompt_len:],
        skip_special_tokens=True,
    ).replace(" ", "")  # presumably strips the spaces decode puts between tokens
    return output
with gr.Blocks() as demo:
    # Layout: chat transcript plus input row on the left, decoding sliders on the right.
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot().style(height=500)
            with gr.Row():
                with gr.Column(scale=4):
                    msg = gr.Textbox(
                        show_label=False, placeholder="Enter text and press enter"
                    ).style(container=False)
                    msg = cast(gr.Textbox, msg)
                with gr.Column(scale=1):
                    button = gr.Button("Generate")
                with gr.Column(scale=1):
                    clear = gr.Button("Clear")
        with gr.Column(scale=1):
            alpha = gr.Slider(0, 1, 0.5, step=0.01, label="Penalty Alpha")
            topk = gr.Slider(1, 50, 5, step=1, label="Top K")

    tokenizer, model = load_model()

    def on_message(
        user_message: str, history: list[list[str]], alpha: float, topk: int
    ):
        """Append the user's message and the model's reply to the chat.

        Args:
            user_message: Text from the input box.
            history: Current chatbot value, a list of ``[user, bot]`` pairs.
            alpha: Penalty-alpha slider value.
            topk: Top-K slider value.

        Returns:
            ``("", new_history)``: the empty string clears the textbox.
        """
        history = history or []
        # Don't run the model or add an empty user bubble for blank input.
        if not user_message.strip():
            return "", history
        bot_message = generate(
            tokenizer,
            model,
            user_message,
            alpha=alpha,
            topk=topk,
        )
        return "", [*history, [user_message, bot_message]]

    # Enter in the textbox and the Generate button trigger the same handler.
    chat_inputs = [msg, chatbot, alpha, topk]
    chat_outputs = [msg, chatbot]
    msg.submit(on_message, inputs=chat_inputs, outputs=chat_outputs)
    button.click(on_message, inputs=chat_inputs, outputs=chat_outputs)
    clear.click(lambda: None, None, chatbot)
if __name__ == "__main__":
    # Enable request queuing (concurrency_count=3) and then start the server;
    # Blocks.queue returns the Blocks, so the calls chain.
    demo.queue(concurrency_count=3).launch()