Spaces:
Runtime error
Runtime error
alirani
commited on
Commit
•
ff3389c
1
Parent(s):
076a158
add genre to generation
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ from transformers import AutoTokenizer, TFAutoModelForCausalLM
|
|
3 |
|
4 |
# MODEL TO CALL
|
5 |
|
6 |
-
model_name = "Alirani/distilgpt2-finetuned-synopsis"
|
7 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
model = TFAutoModelForCausalLM.from_pretrained(model_name)
|
9 |
|
@@ -11,7 +11,7 @@ def generate_synopsis(model, tokenizer, title):
|
|
11 |
input_ids = tokenizer(title, return_tensors="tf")
|
12 |
output = model.generate(input_ids['input_ids'], max_length=150, num_beams=5, no_repeat_ngram_size=2, top_k=50, attention_mask=input_ids['attention_mask'])
|
13 |
synopsis = tokenizer.decode(output[0], skip_special_tokens=True)
|
14 |
-
processed_synopsis = "".join(synopsis.split('
|
15 |
return processed_synopsis
|
16 |
|
17 |
favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
|
@@ -24,11 +24,16 @@ st.header('Generate a story')
|
|
24 |
|
25 |
prod_title = st.text_input('Type a title to generate a synopsis')
|
26 |
|
|
|
|
|
|
|
|
|
|
|
27 |
button_synopsis = st.button('Get synopsis')
|
28 |
|
29 |
if button_synopsis:
|
30 |
if len(prod_title.split(' ')) > 0:
|
31 |
-
gen_synopsis = generate_synopsis(model, tokenizer, f"{prod_title}
|
32 |
st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
|
33 |
else:
|
34 |
st.write('Write a title for the generator to work !')
|
|
|
3 |
|
4 |
# MODEL TO CALL
|
5 |
|
6 |
+
model_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
|
7 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
model = TFAutoModelForCausalLM.from_pretrained(model_name)
|
9 |
|
|
|
11 |
input_ids = tokenizer(title, return_tensors="tf")
|
12 |
output = model.generate(input_ids['input_ids'], max_length=150, num_beams=5, no_repeat_ngram_size=2, top_k=50, attention_mask=input_ids['attention_mask'])
|
13 |
synopsis = tokenizer.decode(output[0], skip_special_tokens=True)
|
14 |
+
processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
|
15 |
return processed_synopsis
|
16 |
|
17 |
favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
|
|
|
24 |
|
25 |
prod_title = st.text_input('Type a title to generate a synopsis')
|
26 |
|
27 |
+
option_genres = st.selectbox(
|
28 |
+
'Select a genre to tailor your synopsis',
|
29 |
+
('Family', 'Romance', 'Comedy', 'Action', 'Documentary', 'Adventure', 'Drama', 'Mystery', 'Crime', 'Thriller', 'Science Fiction', 'History', 'Music', 'Western', 'Fantasy', 'TV Movie', 'Horror', 'Animation', 'Reality')
|
30 |
+
)
|
31 |
+
|
32 |
button_synopsis = st.button('Get synopsis')
|
33 |
|
34 |
if button_synopsis:
|
35 |
if len(prod_title.split(' ')) > 0:
|
36 |
+
gen_synopsis = generate_synopsis(model, tokenizer, f"{prod_title} | {option_genres} | ")
|
37 |
st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
|
38 |
else:
|
39 |
st.write('Write a title for the generator to work !')
|