Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoModelForCausalLM | |
from transformers import BloomTokenizerFast | |
from transformers import pipeline, set_seed | |
import random | |
model_name = "bloom-560m" | |
model = AutoModelForCausalLM.from_pretrained(f'jslin09/{model_name}-finetuned-fraud') | |
tokenizer = BloomTokenizerFast.from_pretrained(f'bigscience/{model_name}', bos_token = '<s>', eos_token = '</s>') | |
def rnd_generate(prompt): | |
rnd_seed = random.randint(10, 500) | |
set_seed(rnd_seed) | |
inputs = tokenizer(prompt, return_tensors="pt") # 回傳的張量使用 Pytorch的格式。如果是 Tensorflow 格式的話,則指定為 "tf"。 | |
results = model.generate(inputs["input_ids"], | |
max_length=500, | |
num_return_sequences=1, # 產生 1 個句子回來。 | |
do_sample=True, | |
temperature=0.75, | |
top_k=50, | |
top_p=0.9) | |
return tokenizer.decode(results[0]) | |
def generate(prompt): | |
result_length = len(prompt) + 4 | |
inputs = tokenizer(prompt, return_tensors="pt") # 回傳的張量使用 Pytorch的格式。如果是 Tensorflow 格式的話,則指定為 "tf"。 | |
results = model.generate(inputs["input_ids"], | |
num_return_sequences=2, # 產生 2 個句子回來。 | |
max_length=result_length, | |
early_stopping=True, | |
do_sample=True, | |
top_k=50, | |
top_p=0.9 | |
) | |
return tokenizer.decode(results[0]) | |
examples = [ | |
["闕很大明知金融帳戶之存摺、提款卡及密碼係供自己使用之重要理財工具,"], | |
["梅友乾明知其無資力支付酒店消費,亦無付款意願,竟意圖為自己不法之所有,"], | |
["王大明意圖為自己不法所有,基於竊盜之犯意,"] | |
] | |
prompts = [ | |
["輸入寫書類的句子,讓電腦生成下一句。或是按以下的範例句子。"], | |
["輸入寫書類的開頭句子,讓電腦隨機生成整篇草稿。"] | |
] | |
with gr.Blocks() as demo: | |
gr.Markdown( | |
""" | |
<h1 style="text-align: center;">Legal Document Drafting</h1> | |
""") | |
with gr.Row(): | |
with gr.Column(): | |
result = gr.components.Textbox(lines=7, label="Writing Assist", placeholder=prompts[0]) | |
prompt = gr.components.Textbox(lines=2, label="Prompt", placeholder=examples[0], visible=False) | |
gr.Examples(examples, label='Examples', inputs=[prompt]) | |
prompt.change(generate, inputs=[prompt], outputs=[result]) | |
btn = gr.Button("Next sentence") | |
btn.click(generate, inputs=[result], outputs=[result]) | |
with gr.Column(): | |
result2 = gr.components.Textbox(lines=7, label="Random Generative", show_label=True, placeholder=prompts[1]) | |
gr.Examples(examples, label='Examples', inputs=[result2]) | |
btn = gr.Button("Random Drafting") | |
btn.click(rnd_generate, inputs=[result2], outputs=[result2]) | |
if __name__ == "__main__": | |
demo.launch() | |
# gr.Interface.load("models/jslin09/bloom-560m-finetuned-fraud").launch() |