# Spaces status: Runtime error
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
# Single source of truth for the checkpoint so the model and its tokenizer
# are always loaded from the same repository.
MODEL_NAME = "bigscience/bloom-1b3"

# Loaded once at import time; question_bloom reads these module-level names
# on every request instead of reloading the weights.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def post_process_sentence(input_sentence, generated_sentence):
    """Trim a raw generation to one line and append a fresh bullet marker.

    Every occurrence of ``input_sentence`` is removed from
    ``generated_sentence`` to isolate the continuation. Then:

    * if the continuation contains a newline, only the text before the
      first newline is kept;
    * otherwise the full ``generated_sentence`` (prompt included) is kept.

    In both cases double spaces are collapsed in a single pass and a
    newline followed by ``"- "`` is appended.

    Args:
        input_sentence: The prompt that was given to the model.
        generated_sentence: The decoded model output, which normally
            starts with the prompt.

    Returns:
        The kept text followed by a newline and ``"- "``.
    """
    continuation = generated_sentence.replace(input_sentence, "")
    first_line, newline, _ = continuation.partition("\n")

    # Without a line break there is nothing to cut, so the whole generation
    # (prompt included) is returned, as in the original branch.
    kept = first_line if newline else generated_sentence

    # NOTE(review): the source replaced a space with an identical space,
    # which is a no-op; collapsing double spaces is the presumed intent.
    return kept.replace("  ", " ") + "\n- "
def generate_single(model, tokenizer, input_sentence, max_length=50, top_k=0, temperature=0.7, do_sample=True, seed=42):
    """Generate one continuation of ``input_sentence`` and post-process it.

    Args:
        model: Object exposing ``generate``; called with the encoded prompt
            and the sampling options below.
        tokenizer: Object exposing ``encode`` and ``decode``.
        input_sentence: Prompt text.
        max_length: Number of tokens to generate on top of the prompt.
        top_k: Top-k value forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        do_sample: Whether to sample instead of decoding greedily.
        seed: Seed passed to ``set_seed`` before generation.

    Returns:
        The result of ``post_process_sentence`` applied to the prompt and
        the decoded output.
    """
    # UI sliders may deliver floats; the seed and the token counts must be ints.
    set_seed(int(seed))

    input_ids = tokenizer.encode(input_sentence, return_tensors="pt")

    # ``max_length`` passed to generate() counts prompt tokens plus new
    # tokens, so the budget is based on the tokenized prompt length, not on
    # the number of characters in the prompt string.
    prompt_token_count = input_ids.shape[-1]

    # Sampling with a non-positive temperature is rejected by the library;
    # zero temperature is the greedy limit, so decode greedily instead.
    if do_sample and temperature <= 0:
        do_sample = False
        temperature = 1.0

    output = model.generate(
        input_ids,
        do_sample=do_sample,
        max_length=prompt_token_count + int(max_length),
        top_k=int(top_k),
        temperature=temperature,
    )

    generated_sentence = tokenizer.decode(output[0], skip_special_tokens=True)
    return post_process_sentence(input_sentence, generated_sentence)
def question_bloom(input_sentence, max_length, temperature, do_sample=True, top_k=3, seed=42):
    """Gradio callback: run one generation and return the predicted text.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        input_sentence: Prompt text from the input box.
        max_length: Number of tokens to generate.
        temperature: Sampling temperature.
        do_sample: Whether to sample.
        top_k: Top-k value for sampling.
        seed: Random seed for generation.

    Returns:
        The segment immediately before the final ``"\\n-"`` of the
        post-processed output.
    """
    post_processed_output = generate_single(
        model,
        tokenizer,
        input_sentence,
        temperature=temperature,
        max_length=max_length,
        do_sample=do_sample,
        top_k=top_k,
        seed=seed,
    )
    # post_process_sentence always ends its result with "\n- ", so the split
    # yields at least two pieces and [-2] is the text before that marker.
    return post_processed_output.split("\n-")[-2]
# The input components are passed positionally to question_bloom, so their
# order must match its parameters:
# input_sentence, max_length, temperature, do_sample, top_k, seed.
# Top-level components (gr.Slider / gr.Checkbox with ``value=``) are used
# throughout, consistent with the gr.Textbox components already in use,
# instead of the deprecated ``gr.inputs`` namespace and ``default=``.
demo = gr.Interface(
    fn=question_bloom,
    inputs=[
        gr.Textbox(lines=10, label="Input code"),
        gr.Slider(
            minimum=8,
            maximum=256,
            step=1,
            value=8,
            label="Number of tokens to generate",
        ),
        gr.Slider(
            minimum=0,
            maximum=2,
            step=0.1,
            value=0.6,
            label="Temperature",
        ),
        gr.Checkbox(value=True, label="Do Sample"),
        gr.Slider(
            minimum=0,
            maximum=10,
            step=1,
            value=3,
            label="Top K",
        ),
        gr.Slider(
            minimum=0,
            maximum=256,
            step=1,
            value=42,
            label="Random seed for generation",
        ),
    ],
    outputs=gr.Textbox(label="Predicted sentence", lines=10),
)
demo.launch()