File size: 4,189 Bytes
bd61da5
3894141
9f4fb0e
233c774
 
 
9f4fb0e
 
 
 
233c774
 
 
 
ff3389c
076a158
bd61da5
3894141
 
 
 
9f4fb0e
 
3894141
 
 
9f4fb0e
bd61da5
 
0f8c139
bd61da5
3894141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e605db
3894141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b329ae
ff377ee
6b329ae
ff377ee
 
a6ac941
ff377ee
 
 
 
7db6225
ff377ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233c774
3894141
 
ff3389c
233c774
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import streamlit as st
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSequenceClassification

# MODEL TO CALL

generator_name = "Alirani/distilgpt2-finetuned-synopsis-genres_final"


# show_spinner=False keeps the loader from drawing a spinner ahead of the
# st.set_page_config call further down in this script.
@st.cache_resource(show_spinner=False)
def _load_generator(checkpoint):
    """Return ``(tokenizer, model)`` for the synopsis generator checkpoint.

    Wrapped in ``st.cache_resource`` so the checkpoint is loaded once and the
    same objects are handed back on later script runs instead of being
    re-created every time a widget is touched.
    """
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        TFAutoModelForCausalLM.from_pretrained(checkpoint),
    )


tokenizer_gen, model_gen = _load_generator(generator_name)

def generate_synopsis(model, tokenizer, title):
    """Generate a synopsis by letting *model* continue the prompt *title*.

    The decoded output is split on ``|`` and the third segment is taken as the
    synopsis (callers in this file send ``"<title> | <genre> | <opening>"``).
    That segment is cut after its last full stop so the result does not end
    in the middle of a sentence.

    Args:
        model: Object exposing ``generate(input_ids, ...)``.
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``attention_mask``; must also expose ``decode``.
        title: Prompt text handed to the tokenizer.

    Returns:
        The stripped synopsis text. When the decoded text has fewer than
        three ``|``-separated segments the whole decoded text is used, and
        when there is no full stop to cut at the text is returned untrimmed
        rather than being discarded.
    """
    encoded = tokenizer(title, return_tensors="tf")
    # NOTE(review): top_k is a sampling option; with beam search and no
    # do_sample it presumably has no effect - confirm before relying on it.
    output = model.generate(
        encoded['input_ids'],
        max_length=150,
        num_beams=5,
        no_repeat_ngram_size=2,
        top_k=50,
        attention_mask=encoded['attention_mask'],
    )
    decoded = tokenizer.decode(output[0], skip_special_tokens=True)

    # Indexing [2] directly would raise IndexError if the separators are
    # missing from the decoded text.
    segments = decoded.split('|')
    body = segments[2] if len(segments) > 2 else decoded

    # rpartition returns ('', '', body) when there is no '.', which would
    # otherwise collapse the synopsis to an empty string.
    head, stop, _ = body.rpartition('.')
    if not stop:
        return body.strip()
    return (head + stop).strip()

classifier_name = "Alirani/overview_classifier_final"


# show_spinner=False keeps the loader from drawing a spinner ahead of the
# st.set_page_config call further down in this script.
@st.cache_resource(show_spinner=False)
def _load_classifier(checkpoint):
    """Return ``(tokenizer, model)`` for the genre classifier checkpoint.

    Wrapped in ``st.cache_resource`` so the checkpoint is loaded once and the
    same objects are handed back on later script runs instead of being
    re-created every time a widget is touched.
    """
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        TFAutoModelForSequenceClassification.from_pretrained(checkpoint),
    )


tokenizer_clf, model_clf = _load_classifier(classifier_name)

def generate_classification(model, tokenizer, title, overview):
    """Return the label *model* predicts for a title/overview pair.

    The two texts are joined as ``"<title> | <overview>"``, tokenized, and
    passed through the model; the index of the largest logit of the first
    (only) row is mapped to its name via ``model.config.id2label``.
    """
    prompt = f"{title} | {overview}"
    encoded = tokenizer(prompt, padding=True, truncation=True, return_tensors="tf")
    logits = model(**encoded).logits
    # A single prompt was encoded, so only row 0 of the argmax is relevant.
    best_index = tf.math.argmax(logits, axis=-1)[0]
    return model.config.id2label[int(best_index)]

# Icon shown for the page.
favicon = "https://i.ibb.co/JRdhFZg/favicon-32x32.png"

# Page-level settings: title, icon, wide layout and automatic sidebar state.
st.set_page_config(
    page_title="Synopsis Generator",
    page_icon=favicon,
    layout='wide',
    initial_sidebar_state='auto',
)

st.title('Demo of a Synopsis Classifier & Generator')

# Tool selector. It is created with index=None, and the code further down
# falls through to a "select a functionality" prompt when neither option
# matches.
functionality = st.radio(
    label="Choose the function to use",
    options=["Classification", "Generation"],
    captions=['Classify title & synopsis into genres', 'Generate synopsis from title & genres'],
    index=None,
)

# Render the widgets for whichever tool was picked in the radio selector;
# the final branch prompts for a choice when neither option matches.
if functionality == "Classification":

    # CLASSIFY A SYNOPSIS

    st.header('Classify a story')

    prod_title = st.text_input('Type a title to classify a synopsis')

    prod_synopsis = st.text_area('Type a synopsis to classify it')

    button_classify = st.button('Get genre')

    if button_classify:
        # str.split(' ') never returns an empty list (''.split(' ') == ['']),
        # so a length check on it cannot detect blank input. Test the
        # stripped text instead, and combine with `and` rather than `&`,
        # which binds tighter than `>`.
        if prod_title.strip() and prod_synopsis.strip():
            classified_genre = generate_classification(model_clf, tokenizer_clf, prod_title, prod_synopsis)
            st.write('The genre of the title & synopsis is : ', classified_genre)
        else:
            st.write('Write a title & synopsis for the classifier to work !')

elif functionality == "Generation":

    # GENERATE A SYNOPSIS

    st.header('Generate a story')

    prod_title = st.text_input('Type a title to generate a synopsis')

    option_genres = st.selectbox(
        'Select a genre to tailor your synopsis',
        ('Family', 'Romance', 'Comedy', 'Action', 'Documentary', 'Adventure', 'Drama', 'Mystery', 'Crime', 'Thriller', 'Science Fiction', 'History', 'Music', 'Western', 'Fantasy', 'TV Movie', 'Horror', 'Animation', 'Reality'),
        index=None,
        placeholder="Select genres..."
        )

    complete_synopsis = st.toggle('Synopsis completion')

    # Optional opening words for the model to continue. With the toggle off
    # the prompt simply ends after the genre separator.
    if complete_synopsis:
        pre_synopsis = st.text_input('Type the beginning of your synopsis')
    else:
        pre_synopsis = ''

    button_synopsis = st.button('Get synopsis')

    if button_synopsis:
        # option_genres is None until a genre is picked (index=None above):
        # len(None) would raise TypeError and formatting it would put the
        # literal text "None" into the prompt, so a genre is required in
        # both modes.
        if prod_title.strip() and option_genres:
            gen_synopsis = generate_synopsis(model_gen, tokenizer_gen, f"{prod_title} | {option_genres} | {pre_synopsis}")
            st.text_area('Generated synopsis', value=gen_synopsis, disabled=True)
        else:
            st.write('Write a title & select a genre for the generator to work !')

else:
    st.write("Select a functionality ! 😊")