# bloom-1b3-gen / app.py
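"""Gradio space that generates a short continuation of an input prompt with the bigscience/bloom-1b3 causal language model."""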
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
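
# Load the BLOOM-1b3 checkpoint and its tokenizer once at startup;
# use_cache=True keeps the key/value cache enabled during generation.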
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b3", use_cache=True)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b3")
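

# Tidy the raw generation: collapse double spaces, keep only the first newly
# generated line when the model produced several, and append "\n- " as a
# delimiter for question_bloom to split on.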
def post_process_sentence(input_sentence, generated_sentence):
    # Text produced beyond the prompt.
    new_sentence = generated_sentence.replace(input_sentence, "")
    if "\n" not in new_sentence:
        # Single-line continuation: return the full decoded text.
        return generated_sentence.replace("  ", " ") + "\n- "
    else:
        # Multi-line continuation: keep only the first generated line.
        return (new_sentence.split("\n")[0]).replace("  ", " ") + "\n- "
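

# Run one seeded generation pass over the prompt and post-process the decoded
# output. top_k=0 disables top-k filtering; top_k > 0 restricts sampling to
# the k most likely tokens at each step.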
def generate_single(model, tokenizer, input_sentence, max_length=50, top_k=0, temperature=0.7, do_sample=True, seed=42):
    set_seed(int(seed))  # make sampling reproducible across runs
    input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
    output = model.generate(
        input_ids,
        do_sample=do_sample,
        # max_length is counted in tokens, so measure the prompt in tokens
        # (input_ids.shape[1]) rather than in characters.
        max_length=input_ids.shape[1] + int(max_length),
        top_k=int(top_k),
        temperature=temperature,
    )
    generated_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
    return post_process_sentence(input_sentence, generated_sentence)
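

# Callback wired to the Gradio interface below: forwards the UI settings to
# generate_single and returns the cleaned-up continuation.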
def question_bloom(input_sentence, max_length, temperature, do_sample=True, top_k=3, seed=42):
    post_processed_output = generate_single(model, tokenizer, input_sentence, temperature=temperature, max_length=max_length, do_sample=do_sample, top_k=top_k, seed=seed)
    # post_process_sentence appends a trailing "\n- ", so the second-to-last
    # chunk of the split is the text to display.
    return post_processed_output.split("\n-")[-2]
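

# Build the UI. The components are passed to question_bloom positionally:
# prompt, number of tokens, temperature, sampling switch, top-k, seed.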
gr.Interface(
    question_bloom,
    inputs=[
        gr.Textbox(lines=10, label="Input code"),
        gr.Slider(
            minimum=8,
            maximum=256,
            step=1,
            value=8,
            label="Number of tokens to generate",
        ),
        gr.Slider(
            minimum=0,
            maximum=2,
            step=0.1,
            value=0.6,
            label="Temperature",
        ),
        gr.Checkbox(value=True, label="Do Sample"),
        gr.Slider(
            minimum=0,
            maximum=10,
            step=1,
            value=3,
            label="Top K",
        ),
        gr.Slider(
            minimum=0,
            maximum=256,
            step=1,
            value=42,
            label="Random seed for generation",
        ),
    ],
    outputs=gr.Textbox(label="Predicted sentence", lines=10),
).launch()