Axolotlily's picture
Update app.py
4d3ea95
raw
history blame
2.9 kB
"""
RPG Room Generator App
Sets the sampling parameters and provides a minimal interface for the user.
https://huggingface.co/blog/how-to-generate
"""
import re

import gradio as gr
from gradio import inputs # allows easier doc lookup in Pycharm
import transformers as tr
MPATH = "./models/mdl_roomgen2"
MODEL = tr.GPT2LMHeadModel.from_pretrained(MPATH)

# The tokenizer is rebuilt on every start from the stock "gpt2" vocabulary plus
# the special tokens below.
# TODO: save the tokenizer alongside the model and load it from MPATH instead.
# Keep the entries in this order: newly added tokens receive ids in insertion
# order, which presumably has to match how the model was fine-tuned.
SPECIAL_TOKENS = dict(
    eos_token="<|EOS|>",
    bos_token="<|endoftext|>",
    pad_token="<pad>",
    sep_token="<|body|>",
)
TOK = tr.GPT2Tokenizer.from_pretrained("gpt2")
TOK.add_special_tokens(SPECIAL_TOKENS)
# Sampling presets offered in the UI, keyed by the label the user picks.
# top_k is a raw token count; temperature and top_p are stored as percentages
# and divided by 100 by generate_room before being passed to MODEL.generate.
SAMPLING_OPTIONS = {
    "Reasonable": dict(top_k=25, temperature=50, top_p=60),
    "Odd": dict(top_k=50, temperature=75, top_p=90),
    "Insane": dict(top_k=300, temperature=100, top_p=85),
}
def generate_room(room_name, room_desc, max_length, sampling_method):
    """
    Use the pretrained model to generate text for a dungeon room.

    Args:
        room_name: Name of the room. It is part of the prompt but is not
            included in the returned text.
        room_desc: Optional start of the room description. When non-blank it
            is appended to the prompt and kept at the start of the returned text.
        max_length: Maximum total length in tokens (prompt included), passed to
            ``MODEL.generate``. Cast to ``int`` because UI sliders may deliver floats.
        sampling_method: Key into ``SAMPLING_OPTIONS``.

    Returns:
        Room description text. It is cut at the end-of-sequence token, stripped
        of special tokens and surrounding whitespace, and trimmed back to the
        last period when one exists past the first character.

    Raises:
        KeyError: if ``sampling_method`` is not a key of ``SAMPLING_OPTIONS``.
    """
    # The part of the prompt that is NOT returned to the user.
    prefix = " ".join(
        [SPECIAL_TOKENS["bos_token"], room_name, SPECIAL_TOKENS["sep_token"]]
    )
    # Only append the description when there is one. Joining an empty string
    # would leave a trailing space after the separator, which (depending on the
    # tokenizer version) can be encoded as a stray lone-space token.
    if room_desc and room_desc.strip():
        prompt = f"{prefix} {room_desc}"
    else:
        prompt = prefix

    n_skip = len(TOK.encode(prefix))
    ids = TOK.encode(prompt, return_tensors="pt")

    # Sample. Presets store temperature and top_p as percentages.
    options = SAMPLING_OPTIONS[sampling_method]
    output = MODEL.generate(
        ids,
        max_length=int(max_length),
        do_sample=True,
        top_k=options["top_k"],
        temperature=options["temperature"] / 100.0,
        top_p=options["top_p"] / 100.0,
    )

    generated = output[0][n_skip:].tolist()
    # Stop at the end-of-sequence marker. Without this, decode() emits the raw
    # "<|EOS|>" text plus anything sampled after it.
    eos_id = TOK.eos_token_id
    if eos_id in generated:
        generated = generated[:generated.index(eos_id)]
    text = TOK.decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # Collapse repeated spaces and drop the leading space left after the separator.
    text = re.sub(r" {2,}", " ", text).strip()

    # Slice off the last partial sentence.
    last_period = text.rfind(".")
    if last_period > 0:
        text = text[:last_period + 1]
    return text
if __name__ == "__main__":
    # Input components from gradio.inputs (deprecated in later gradio
    # releases). Their order must match generate_room's positional parameters.
    room_inputs = [
        inputs.Textbox(lines=1, label="Room Name"),
        inputs.Textbox(lines=3, label="Start of Room Description (Optional)", default=""),
        # An explicit integer step keeps the token budget a whole number.
        inputs.Slider(minimum=50, maximum=1000, step=1, default=200, label="Length"),
        inputs.Radio(choices=list(SAMPLING_OPTIONS), default="Odd", label="Craziness"),
    ]
    iface = gr.Interface(
        title="RPG Room Generator",
        fn=generate_room,
        inputs=room_inputs,
        outputs="text",
        layout="horizontal",
        allow_flagging="never",
        theme="dark",
    )
    # The values launch() returns (app, local URL, share URL) were never used.
    iface.launch()