alirani committed on
Commit
3894141
1 Parent(s): 0f8c139

add classifier

Browse files
Files changed (1) hide show
  1. app.py +57 -22
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import streamlit as st
 
2
  from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSequenceClassification
3
 
4
  # MODEL TO CALL
@@ -7,10 +8,6 @@ generator_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"
7
  tokenizer_gen = AutoTokenizer.from_pretrained(generator_name)
8
  model_gen = TFAutoModelForCausalLM.from_pretrained(generator_name)
9
 
10
- classifier_name = "Alirani/overview_classifier_final"
11
- tokenizer_clf = AutoTokenizer.from_pretrained(classifier_name)
12
- model_clf = TFAutoModelForSequenceClassification.from_pretrained(classifier_name)
13
-
14
  def generate_synopsis(model, tokenizer, title):
15
  input_ids = tokenizer(title, return_tensors="tf")
16
  output = model.generate(input_ids['input_ids'], max_length=150, num_beams=5, no_repeat_ngram_size=2, top_k=50, attention_mask=input_ids['attention_mask'])
@@ -18,33 +15,71 @@ def generate_synopsis(model, tokenizer, title):
18
  processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
19
  return processed_synopsis
20
 
 
 
 
 
21
  def generate_classification(model, tokenizer, title, overview):
22
  tokens = tokenizer(f"{title} | {overview}", padding=True, truncation=True, return_tensors="tf")
23
- output = model(**tokens)
24
- return output
 
25
 
26
  favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
27
 
28
  st.set_page_config(page_title="Synopsis Generator", page_icon = favicon, layout = 'wide', initial_sidebar_state = 'auto')
29
 
30
- st.title('Demo of a Synopsis Generator')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- st.header('Generate a story')
33
 
34
- prod_title = st.text_input('Type a title to generate a synopsis')
 
 
 
 
 
35
 
36
- option_genres = st.selectbox(
37
- 'Select a genre to tailor your synopsis',
38
- ('Family', 'Romance', 'Comedy', 'Action', 'Documentary', 'Adventure', 'Drama', 'Mystery', 'Crime', 'Thriller', 'Science Fiction', 'History', 'Music', 'Western', 'Fantasy', 'TV Movie', 'Horror', 'Animation', 'Reality'),
39
- index=None,
40
- placeholder="Select genres..."
41
- )
42
 
43
- button_synopsis = st.button('Get synopsis')
44
 
45
- if button_synopsis:
46
- if len(prod_title.split(' ')) > 0:
47
- gen_synopsis = generate_synopsis(model_gen, tokenizer_gen, f"{prod_title} | {option_genres} | ")
48
- st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
49
- else:
50
- st.write('Write a title for the generator to work !')
 
1
  import streamlit as st
2
+ import tensorflow as tf
3
  from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSequenceClassification
4
 
5
  # MODEL TO CALL
 
8
  tokenizer_gen = AutoTokenizer.from_pretrained(generator_name)
9
  model_gen = TFAutoModelForCausalLM.from_pretrained(generator_name)
10
 
 
 
 
 
11
  def generate_synopsis(model, tokenizer, title):
12
  input_ids = tokenizer(title, return_tensors="tf")
13
  output = model.generate(input_ids['input_ids'], max_length=150, num_beams=5, no_repeat_ngram_size=2, top_k=50, attention_mask=input_ids['attention_mask'])
 
15
  processed_synopsis = "".join(synopsis.split('|')[2].rpartition('.')[:2]).strip()
16
  return processed_synopsis
17
 
18
+ classifier_name = "Alirani/overview_classifier_final"
19
+ tokenizer_clf = AutoTokenizer.from_pretrained(classifier_name)
20
+ model_clf = TFAutoModelForSequenceClassification.from_pretrained(classifier_name)
21
+
22
  def generate_classification(model, tokenizer, title, overview):
23
  tokens = tokenizer(f"{title} | {overview}", padding=True, truncation=True, return_tensors="tf")
24
+ output = model(**tokens).logits
25
+ predicted_class_id = int(tf.math.argmax(output, axis=-1)[0])
26
+ return model.config.id2label[predicted_class_id]
27
 
28
  favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"
29
 
30
  st.set_page_config(page_title="Synopsis Generator", page_icon = favicon, layout = 'wide', initial_sidebar_state = 'auto')
31
 
32
# Page heading for the combined classifier / generator demo.
st.title('Demo of a Synopsis Classifier & Generator')

# Mode selector. With index=None no option starts selected, and the final
# `else` branch below prompts the user until one is chosen.
functionality = st.radio(
    "Choose the function to use",
    ["Classification", "Generation"],
    captions=['Classify title & synopsis into genres', 'Generate synopsis from title & genres'],
    index=None,
)

if functionality == "Classification":

    # CLASSIFY A SYNOPSIS

    st.header('Classify a story')

    prod_title = st.text_input('Type a title to classify a synopsis')

    prod_synopsis = st.text_area('Type a synopsis to classify it')

    button_classify = st.button('Get genre')

    if button_classify:
        # Require real text in both fields. The previous check used
        # `len(x.split(' ')) > 0`, which is always true (''.split(' ') is
        # ['']), combined with `&`, which binds tighter than `>` and made
        # the outcome depend on the parity of the synopsis word count.
        if prod_title.strip() and prod_synopsis.strip():
            # Use the sequence-classification model and its tokenizer here,
            # not the text-generation pair.
            classified_genre = generate_classification(model_clf, tokenizer_clf, prod_title, prod_synopsis)
            st.write('The genre of the title & synopsis is : ', classified_genre)
        else:
            st.write('Write a title & synopsis for the classifier to work !')

elif functionality == "Generation":

    # GENERATE A SYNOPSIS

    st.header('Generate a story')

    prod_title = st.text_input('Type a title to generate a synopsis')

    option_genres = st.selectbox(
        'Select a genre to tailor your synopsis',
        ('Family', 'Romance', 'Comedy', 'Action', 'Documentary', 'Adventure', 'Drama', 'Mystery', 'Crime', 'Thriller', 'Science Fiction', 'History', 'Music', 'Western', 'Fantasy', 'TV Movie', 'Horror', 'Animation', 'Reality'),
        index=None,
        placeholder="Select genres...",
    )

    button_synopsis = st.button('Get synopsis')

    if button_synopsis:
        # A blank or whitespace-only title must reach the `else` message;
        # `len(prod_title.split(' ')) > 0` could never be false.
        if prod_title.strip():
            # NOTE(review): option_genres is interpolated as-is. With
            # index=None the selectbox presumably yields None until a genre
            # is picked, which would put the text "None" into the prompt -
            # confirm whether a genre should be required.
            gen_synopsis = generate_synopsis(model_gen, tokenizer_gen, f"{prod_title} | {option_genres} | ")
            st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
        else:
            st.write('Write a title for the generator to work !')

else:
    st.write("Select a functionality ! 😊")
 
 
 
 
84
 
 
85