File size: 3,376 Bytes
32b0450
dccf4fa
 
d97db75
065361a
7fc842c
dccf4fa
e4bbe61
c203e02
7fc842c
065361a
4eebfca
065361a
 
 
e0219f0
4eebfca
065361a
 
 
 
 
 
dccf4fa
 
 
 
315025d
dccf4fa
 
 
2ace237
dccf4fa
 
 
1616f44
351b9e5
081ad0c
9bd8a55
dccf4fa
7fc842c
 
a7c67ff
 
 
 
 
28f307c
 
 
48e0bd2
28f307c
f2aa585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dccf4fa
8c28f7d
8d21a67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from transformers import AutoModelForCausalLM
from transformers import BloomTokenizerFast
from transformers import pipeline, set_seed
import random

model_name = "bloom-560m"
model = AutoModelForCausalLM.from_pretrained(f'jslin09/{model_name}-finetuned-fraud')
tokenizer = BloomTokenizerFast.from_pretrained(f'bigscience/{model_name}', bos_token = '<s>', eos_token = '</s>', pad_token = '<pad>')

def rnd_generate(prompt):
    rnd_seed = random.randint(10, 500)
    set_seed(rnd_seed)
    inputs = tokenizer(prompt, return_tensors="pt") # 回傳的張量使用 Pytorch的格式。如果是 Tensorflow 格式的話,則指定為 "tf"。
    results = model.generate(inputs["input_ids"],
                       max_length=500,
                       num_return_sequences=1, # 產生 1 個句子回來。
                       do_sample=True,
                       temperature=0.75,
                       top_k=50, 
                       top_p=0.9)
    return tokenizer.decode(results[0])

def generate(prompt):
    result_length = len(prompt) + 4
    inputs = tokenizer(prompt, return_tensors="pt") # 回傳的張量使用 Pytorch的格式。如果是 Tensorflow 格式的話,則指定為 "tf"。
    results = model.generate(inputs["input_ids"],
                       num_return_sequences=2, # 產生 2 個句子回來。
                       max_length=result_length,
                       early_stopping=True,
                       do_sample=True, 
                       top_k=50, 
                       top_p=0.9
                      )
    return tokenizer.decode(results[0])

examples = [
    ["闕很大明知金融帳戶之存摺、提款卡及密碼係供自己使用之重要理財工具,"],
    ["梅友乾明知其無資力支付酒店消費,亦無付款意願,竟意圖為自己不法之所有,"],
    ["王大明意圖為自己不法所有,基於竊盜之犯意,"]
]

prompts = [
    ["輸入寫書類的句子,讓電腦生成下一句。或是按以下的範例句子。"],
    ["輸入寫書類的開頭句子,讓電腦隨機生成整篇草稿。"]
]

with gr.Blocks() as demo:
    gr.Markdown(
    """
    <h1 style="text-align: center;">Legal Document Drafting</h1>
    """)
    with gr.Row() as row:
        with gr.Column(scale=1, min_width=600):
            with gr.Tab("Writing Assist"):
                result = gr.components.Textbox(lines=7, label="Writing Assist", placeholder=prompts[0])
                prompt = gr.components.Textbox(lines=2, label="Prompt", placeholder=examples[0], visible=False)
                gr.Examples(examples, label='Examples', inputs=[prompt])
                prompt.change(generate, inputs=[prompt], outputs=[result])
                btn = gr.Button("Next sentence")
                btn.click(generate, inputs=[result], outputs=[result])
            with gr.Tab("Random Generative"):
                result2 = gr.components.Textbox(lines=7, label="Random Generative", show_label=True, placeholder=prompts[1])
                gr.Examples(examples, label='Examples', inputs=[result2])
                rnd_btn = gr.Button("Random Drafting")
                rnd_btn.click(rnd_generate, inputs=[result2], outputs=[result2])
        with gr.Column(scale=2, min_width=600):
            gr.Markdown("Legal Document drafting demo")
    
if __name__ == "__main__":
    demo.launch()