alirani commited on
Commit
ff3389c
1 Parent(s): 076a158

add genre to generation

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -3,7 +3,7 @@ from transformers import AutoTokenizer, TFAutoModelForCausalLM
3
 
4
  # MODEL TO CALL
5
 
6
- model_name = "Alirani/distilgpt2-finetuned-synopsis"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = TFAutoModelForCausalLM.from_pretrained(model_name)
9
 
@@ -11,7 +11,7 @@ def generate_synopsis(model, tokenizer, title):
11
  input_ids = tokenizer(title, return_tensors="tf")
12
  output = model.generate(input_ids['input_ids'], max_length=150, num_beams=5, no_repeat_ngram_size=2, top_k=50, attention_mask=input_ids['attention_mask'])
13
  synopsis = tokenizer.decode(output[0], skip_special_tokens=True)
14
- processed_synopsis = "".join(synopsis.split(':')[1].rpartition('.')[:2]).strip()
15
  return processed_synopsis
16
 
17
  favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
@@ -24,11 +24,16 @@ st.header('Generate a story')
24
 
25
  prod_title = st.text_input('Type a title to generate a synopsis')
26
 
 
 
 
 
 
27
  button_synopsis = st.button('Get synopsis')
28
 
29
  if button_synopsis:
30
  if len(prod_title.split(' ')) > 0:
31
- gen_synopsis = generate_synopsis(model, tokenizer, f"{prod_title} : ")
32
  st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
33
  else:
34
  st.write('Write a title for the generator to work !')
 
3
 
4
  # MODEL TO CALL
5
 
6
+ model_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = TFAutoModelForCausalLM.from_pretrained(model_name)
9
 
 
11
  input_ids = tokenizer(title, return_tensors="tf")
12
  output = model.generate(input_ids['input_ids'], max_length=150, num_beams=5, no_repeat_ngram_size=2, top_k=50, attention_mask=input_ids['attention_mask'])
13
  synopsis = tokenizer.decode(output[0], skip_special_tokens=True)
14
+ processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
15
  return processed_synopsis
16
 
17
  favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
 
24
 
25
  prod_title = st.text_input('Type a title to generate a synopsis')
26
 
27
+ option_genres = st.selectbox(
28
+ 'Select a genre to tailor your synopsis',
29
+ ('Family', 'Romance', 'Comedy', 'Action', 'Documentary', 'Adventure', 'Drama', 'Mystery', 'Crime', 'Thriller', 'Science Fiction', 'History', 'Music', 'Western', 'Fantasy', 'TV Movie', 'Horror', 'Animation', 'Reality')
30
+ )
31
+
32
  button_synopsis = st.button('Get synopsis')
33
 
34
  if button_synopsis:
35
  if len(prod_title.split(' ')) > 0:
36
+ gen_synopsis = generate_synopsis(model, tokenizer, f"{prod_title} | {option_genres} | ")
37
  st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
38
  else:
39
  st.write('Write a title for the generator to work !')