jslin09 commited on
Commit
065361a
1 Parent(s): 315025d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -0
app.py CHANGED
@@ -2,11 +2,25 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM
3
  from transformers import BloomTokenizerFast
4
  from transformers import pipeline, set_seed
 
5
 
6
  model_name = "bloom-560m"
7
  model = AutoModelForCausalLM.from_pretrained(f"jslin09/{model_name}-finetuned-fraud")
8
  tokenizer = BloomTokenizerFast.from_pretrained(f'bigscience/{model_name}', bos_token = '<s>', eos_token = '</s>')
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def generate(prompt):
11
  result_length = len(prompt) + 4
12
  inputs = tokenizer(prompt, return_tensors="pt") # 回傳的張量使用 Pytorch的格式。如果是 Tensorflow 格式的話,則指定為 "tf"。
@@ -39,6 +53,11 @@ with gr.Blocks() as demo:
39
  prompt.change(generate, inputs=[prompt], outputs=[result])
40
  btn = gr.Button("Next sentence")
41
  btn.click(generate, inputs=[result], outputs=[result])
 
 
 
 
 
42
 
43
  if __name__ == "__main__":
44
  demo.launch()
 
2
  from transformers import AutoModelForCausalLM
3
  from transformers import BloomTokenizerFast
4
  from transformers import pipeline, set_seed
5
+ import random
6
 
7
  model_name = "bloom-560m"
8
  model = AutoModelForCausalLM.from_pretrained(f"jslin09/{model_name}-finetuned-fraud")
9
  tokenizer = BloomTokenizerFast.from_pretrained(f'bigscience/{model_name}', bos_token = '<s>', eos_token = '</s>')
10
 
11
+ def rnd_generate(prompt):
12
+ rnd_seed = random.randint(10, 500)
13
+ set_seed(rnd_seed)
14
+ inputs = tokenizer(prompt, return_tensors="pt") # 回傳的張量使用 Pytorch的格式。如果是 Tensorflow 格式的話,則指定為 "tf"。
15
+ results = model.generate(inputs["input_ids"],
16
+ max_length=500,
17
+ num_return_sequences=1,
18
+ do_sample=True,
19
+ temperature=0.75,
20
+ top_k=50,
21
+ top_p=0.9)
22
+ return tokenizer.decode(results[0])
23
+
24
  def generate(prompt):
25
  result_length = len(prompt) + 4
26
  inputs = tokenizer(prompt, return_tensors="pt") # 回傳的張量使用 Pytorch的格式。如果是 Tensorflow 格式的話,則指定為 "tf"。
 
53
  prompt.change(generate, inputs=[prompt], outputs=[result])
54
  btn = gr.Button("Next sentence")
55
  btn.click(generate, inputs=[result], outputs=[result])
56
+ with gr.Column():
57
+ result2 = gr.components.Textbox(lines=7, label="生成草稿", show_label=True, placeholder=examples[0])
58
+ gr.Examples(examples, label='例句', inputs=[result2])
59
+ btn = gr.Button("隨機生成草稿")
60
+ btn.click(rnd_generate, inputs=[result2], outputs=[result2])
61
 
62
  if __name__ == "__main__":
63
  demo.launch()