ybelkada committed on
Commit
3c0d201
1 Parent(s): 5c7de89

few updates

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
  model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b3", use_cache=True)
5
  tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b3")
@@ -9,21 +9,22 @@ def post_process_sentence(input_sentence, generated_sentence):
9
  if "\n" not in new_sentence:
10
  return generated_sentence.replace(" ", " ") + "\n- "
11
  else:
12
- return (input_sentence + new_sentence.split("\n")[0]).replace(" ", " ") + "\n- "
13
 
14
- def generate_single(model, tokenizer, input_sentence, max_length=50, top_k=0, temperature=0.7):
 
15
  input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
16
  output = model.generate(
17
- input_ids, do_sample=True,
18
  max_length=len(input_sentence)+max_length,
19
  top_k=top_k,
20
- temperature=temperature
21
  )
22
  generated_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
23
  return post_process_sentence(input_sentence, generated_sentence)
24
 
25
- def question_bloom(input_sentence, max_length, temperature):
26
- post_processed_output = generate_single(model, tokenizer, input_sentence, temperature=temperature, max_length=max_length)
27
  return post_processed_output.split("\n-")[-2]
28
 
29
  gr.Interface(
@@ -44,6 +45,14 @@ gr.Interface(
44
  default=0.6,
45
  label="Temperature",
46
  ),
 
 
 
 
 
 
 
 
47
  ],
48
  outputs=gr.Textbox(label="Predicted sentence", lines=10),
49
  ).launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
3
 
4
  model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b3", use_cache=True)
5
  tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b3")
 
9
  if "\n" not in new_sentence:
10
  return generated_sentence.replace(" ", " ") + "\n- "
11
  else:
12
+ return (new_sentence.split("\n")[0]).replace(" ", " ") + "\n- "
13
 
14
def generate_single(model, tokenizer, input_sentence, max_length=50, top_k=0, temperature=0.7, do_sample=True, seed=42):
    """Generate a continuation of ``input_sentence`` and post-process it.

    Args:
        model: Object whose ``generate`` method is called with the encoded prompt.
        tokenizer: Object providing ``encode`` and ``decode``.
        input_sentence: Prompt text to continue.
        max_length: Number of tokens allowed on top of the prompt.
        top_k: Forwarded unchanged to ``model.generate``.
        temperature: Forwarded unchanged to ``model.generate``.
        do_sample: Forwarded unchanged to ``model.generate``.
        seed: Value handed to ``set_seed`` before generating.

    Returns:
        The result of ``post_process_sentence`` applied to the prompt and the
        decoded generation.
    """
    # NOTE(review): the UI sliders may deliver floats -- TODO confirm; the int()
    # casts keep the seed and the token budget integral either way.
    set_seed(int(seed))
    input_ids = tokenizer.encode(input_sentence, return_tensors="pt")

    # The generation limit is expressed in tokens and includes the prompt, so
    # the budget must be based on the encoded prompt length. The previous
    # len(input_sentence) counted characters, which inflated the limit by an
    # amount that depended on how long the prompt text was.
    prompt_token_count = input_ids.shape[-1]
    output = model.generate(
        input_ids,
        do_sample=do_sample,
        max_length=prompt_token_count + int(max_length),
        top_k=top_k,
        temperature=temperature,
    )
    generated_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
    return post_process_sentence(input_sentence, generated_sentence)
25
 
26
def question_bloom(input_sentence, max_length, temperature, do_sample=True, seed=42):
    """Run one generation for ``input_sentence`` and return the answer text.

    The generation settings are forwarded to ``generate_single`` together with
    the module-level ``model`` and ``tokenizer``. The post-processed output is
    split on ``"\\n-"`` and the second-to-last piece is returned.
    """
    generation_options = {
        "temperature": temperature,
        "max_length": max_length,
        "do_sample": do_sample,
        "seed": seed,
    }
    answer_block = generate_single(model, tokenizer, input_sentence, **generation_options)
    pieces = answer_block.split("\n-")
    return pieces[-2]
29
 
30
  gr.Interface(
 
45
  default=0.6,
46
  label="Temperature",
47
  ),
48
+ gr.inputs.Checkbox(True, label="Do Sample"),
49
+ gr.inputs.Slider(
50
+ minimum=0,
51
+ maximum=256,
52
+ step=1,
53
+ default=42,
54
+ label="Random seed for generation",
55
+ ),
56
  ],
57
  outputs=gr.Textbox(label="Predicted sentence", lines=10),
58
  ).launch()