# Spaces: Runtime error
# (The line above is status text captured from the hosting page; it is not part of the program.)
"""
RPG Room Generator App
Sets the sampling parameters and provides a minimal interface to the user.
https://huggingface.co/blog/how-to-generate
"""
import gradio as gr | |
from gradio import inputs # allows easier doc lookup in Pycharm | |
import transformers as tr | |
# Directory holding the GPT-2 weights that generate_room samples from.
MPATH = "./models/mdl_roomgen2"
MODEL = tr.GPT2LMHeadModel.from_pretrained(MPATH)

# ToDo: Will save tokenizer next time so can replace this with a load
# Until then the stock "gpt2" tokenizer is loaded and the extra tokens are
# registered on it by hand.
# NOTE(review): presumably these tokens (and the order they are added in) must
# match what the model in MPATH was trained with -- confirm before editing.
SPECIAL_TOKENS = {
    'eos_token': '<|EOS|>',
    'bos_token': '<|endoftext|>',
    'pad_token': '<pad>',
    'sep_token': '<|body|>',  # placed between the room name and its description in the prompt
}
TOK = tr.GPT2Tokenizer.from_pretrained("gpt2")
TOK.add_special_tokens(SPECIAL_TOKENS)
# Sampling presets selectable in the UI, keyed by the label shown to the user.
# "top_k" is passed to the model as-is; "temperature" and "top_p" are stored as
# percentages and divided by 100 in generate_room before sampling.
# Insertion order is kept deliberately: it is the order of the radio choices.
SAMPLING_OPTIONS = {
    "Reasonable": {
        "top_k": 25,
        "temperature": 50,
        "top_p": 60,
    },
    "Odd": {
        "top_k": 50,
        "temperature": 75,
        "top_p": 90,
    },
    "Insane": {
        "top_k": 300,
        "temperature": 100,
        "top_p": 85,
    },
}
def generate_room(room_name, room_desc, max_length, sampling_method):
    """
    Uses pretrained model to generate text for a dungeon room.

    The prompt is laid out as "<bos> room_name <sep> room_desc" and the model
    continues it by sampling with the preset chosen in ``sampling_method``.

    Args:
        room_name: Name of the room. It is part of the prompt but is cut from
            the returned text.
        room_desc: Optional start of the room description. It is part of the
            prompt and stays at the start of the returned text. ``None`` is
            treated as an empty string.
        max_length: Length limit handed to ``MODEL.generate`` as
            ``max_length``. Coerced to ``int`` because a slider widget may
            deliver a float.
        sampling_method: Key into ``SAMPLING_OPTIONS``.

    Returns: Room description text, cut off after its last period when a
        period is found past the first character.

    Raises:
        KeyError: If ``sampling_method`` is not a key of ``SAMPLING_OPTIONS``.
    """
    room_desc = room_desc or ""

    # Everything up to and including the separator is prompt-only; build it
    # once so the skip length and the full prompt cannot drift apart.
    header = " ".join(
        [
            SPECIAL_TOKENS["bos_token"],
            room_name,
            SPECIAL_TOKENS["sep_token"],
        ]
    )
    prompt = " ".join([header, room_desc])

    # Only want to skip the room name part
    to_skip = TOK.encode(header, return_tensors="pt")
    ids = TOK.encode(prompt, return_tensors="pt")

    # Sample. Presets store temperature and top_p as percentages.
    options = SAMPLING_OPTIONS[sampling_method]
    output = MODEL.generate(
        ids,
        max_length=int(max_length),
        do_sample=True,
        top_k=options["top_k"],
        temperature=options["temperature"] / 100.,
        top_p=options["top_p"] / 100.,
    )

    # Drop the header tokens, then collapse doubled spaces left around the
    # special tokens (the previous single-space-to-single-space replace did
    # nothing).
    text = TOK.decode(
        output[0][to_skip.shape[1]:],
        clean_up_tokenization_spaces=True,
    ).replace("  ", " ")

    # Slice off last partial sentence
    last_period = text.rfind(".")
    if last_period > 0:
        text = text[:last_period + 1]
    return text
if __name__ == "__main__":
    # Build the web UI. The input widgets are listed in the same order as the
    # positional parameters of generate_room:
    # room_name, room_desc, max_length, sampling_method.
    iface = gr.Interface(
        title="RPG Room Generator",
        fn=generate_room,
        inputs=[
            inputs.Textbox(lines=1, label="Room Name"),
            inputs.Textbox(lines=3, label="Start of Room Description (Optional)", default=""),
            inputs.Slider(minimum=50, maximum=1000, default=200, label="Length"),
            inputs.Radio(choices=list(SAMPLING_OPTIONS.keys()), default="Odd", label="Craziness"),
        ],
        outputs="text",
        layout="horizontal",
        allow_flagging="never",
        theme="dark",
    )
    # NOTE(review): the three-way unpacking and the `gradio.inputs` widgets
    # depend on the installed Gradio version -- confirm against the pinned
    # release if the app fails at start-up.
    app, local_url, share_url = iface.launch()