Spaces:
Running
Running
mkmenta
commited on
Commit
·
77a05a1
1
Parent(s):
e0f03a3
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""GPT-1 and GPT-2 Text Generation demo."""
|
2 |
+
import gradio as gr
|
3 |
+
from torch.cuda import is_available
|
4 |
+
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2Tokenizer, GPT2LMHeadModel
|
5 |
+
|
6 |
+
|
7 |
+
tokenizer = None
|
8 |
+
model = None
|
9 |
+
loaded_model = None
|
10 |
+
|
11 |
+
|
12 |
+
def load_model(model_name):
|
13 |
+
"""Loads the model and tokenizer from HuggingFace."""
|
14 |
+
global tokenizer, model, loaded_model
|
15 |
+
loaded_model = model_name
|
16 |
+
huggingface_model_name = model_name.split('(')[1][:-1]
|
17 |
+
if huggingface_model_name == 'openai-gpt': # GPT-1
|
18 |
+
tokenizer = OpenAIGPTTokenizer.from_pretrained(huggingface_model_name)
|
19 |
+
model = OpenAIGPTLMHeadModel.from_pretrained(huggingface_model_name)
|
20 |
+
else: # GPT-2
|
21 |
+
tokenizer = GPT2Tokenizer.from_pretrained(huggingface_model_name)
|
22 |
+
model = GPT2LMHeadModel.from_pretrained(huggingface_model_name)
|
23 |
+
# Load model in CUDA if available
|
24 |
+
if is_available():
|
25 |
+
model = model.cuda()
|
26 |
+
|
27 |
+
|
28 |
+
def generate(inp, model_name, temperature, top_p, rep_pty, max_length):
|
29 |
+
"""Generates text using the given model and parameters."""
|
30 |
+
if loaded_model != model_name:
|
31 |
+
load_model(model_name)
|
32 |
+
inputs = tokenizer.encode(inp, return_tensors='pt')
|
33 |
+
if is_available():
|
34 |
+
inputs = inputs.cuda()
|
35 |
+
outputs = model.generate(inputs,
|
36 |
+
max_length=max_length,
|
37 |
+
temperature=temperature,
|
38 |
+
num_return_sequences=1,
|
39 |
+
top_p=top_p,
|
40 |
+
repetition_penalty=rep_pty)
|
41 |
+
out = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
42 |
+
if 'GPT-1' in model_name:
|
43 |
+
out = out.replace(inp.lower(), "")
|
44 |
+
else:
|
45 |
+
out = out.replace(inp, "")
|
46 |
+
return out
|
47 |
+
|
48 |
+
|
49 |
+
with gr.Blocks() as demo:
|
50 |
+
gr.Markdown("# Try GPT-1 and GPT-2")
|
51 |
+
with gr.Row():
|
52 |
+
with gr.Column(scale=4):
|
53 |
+
inp = gr.Textbox(label="Input text:", placeholder="Write something here.", lines=10,)
|
54 |
+
out = gr.Textbox(label="Generated text:", lines=25)
|
55 |
+
with gr.Column(scale=1):
|
56 |
+
with gr.Row(scale=1):
|
57 |
+
model_name = gr.Dropdown(label="Select a model:",
|
58 |
+
choices=['GPT-2 XL (gpt2-xl)',
|
59 |
+
'GPT-2 L (gpt2-large)',
|
60 |
+
'GPT-2 M (gpt2-medium)',
|
61 |
+
'GPT-2 S (gpt2)',
|
62 |
+
'GPT-1 (openai-gpt)'],
|
63 |
+
value='GPT-2 XL (gpt2-xl)')
|
64 |
+
btn_run = gr.Button("Generate")
|
65 |
+
temperature = gr.Slider(
|
66 |
+
label="Temperature",
|
67 |
+
info=("Degree of randomness in the output, where higher values make it more unpredictable"
|
68 |
+
" and creative, while lower values make it more deterministic and focused."),
|
69 |
+
minimum=0.01, maximum=3.0, step=0.01, value=0.7)
|
70 |
+
top_p = gr.Slider(
|
71 |
+
label="Top-p",
|
72 |
+
info=("If set to float < 1, only the most probable tokens with probabilities that add up"
|
73 |
+
" to `top_p` or higher are kept for generation."),
|
74 |
+
minimum=0.01, maximum=1.0, step=0.01, value=.9)
|
75 |
+
rep_pty = gr.Slider(label="Repetition Penalty",
|
76 |
+
info="Token repetition penalty. 1.0 means no penalty.",
|
77 |
+
minimum=1.0, maximum=2.0, step=0.01, value=1.2)
|
78 |
+
max_length = gr.Number(label="Max Length",
|
79 |
+
info="The maximum length of the sequence to be generated.",
|
80 |
+
minimum=1, maximum=1024, value=256, precision=0)
|
81 |
+
# Fill the rest of the column with blank space
|
82 |
+
# (I didn't find a better way to do this)
|
83 |
+
with gr.Row(scale=1000):
|
84 |
+
gr.Markdown()
|
85 |
+
btn_run.click(fn=generate, inputs=[inp, model_name, temperature, top_p, rep_pty, max_length], outputs=out)
|
86 |
+
|
87 |
+
demo.launch()
|