import streamlit as st
import nltk

nltk.download("punkt")
nltk.download("wordnet")

from summarize import *
from utils.sentence_embedding import *
from utils.clustering import *
from models.summarizers import *

from nltk.tokenize import sent_tokenize, word_tokenize

import math
import numpy as np  # np.array is used below; don't rely on a wildcard import to provide it
from time import perf_counter


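# Streamlit reruns this script from the top on every widget interaction, so these
# module-level flags are re-initialized on each rerun and only gate a single run.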
START = False
COMPLETED = False
PLACEHOLDER = "Enter your article"

st.markdown("# Inference-Time Optimization for Large Article Summarization 😊")

article = st.text_area(
    label="Welcome! Paste your article below, then press Summarize.",
    value=PLACEHOLDER,
)

model_name = st.sidebar.selectbox(
    label="Pick your model of choice:",
    options=("BART", "Pegasus", "Distill-BART", "RoBERTa")
)

max_length = st.sidebar.slider(
    label="Choose the maximum length of the summary",
    min_value=100,
    max_value=500,
    value=250,
)

min_length = st.sidebar.slider(
    label="Choose the minimum length of the summary",
    min_value=20,
    max_value=150,
    value=50,
)

go = st.button(
    label="Summarize",
    key=0,
)

reset = st.button(
    label="Reset",
    key=1,
)


START = go
tmp_out = st.empty()

if reset:
    # Pressing Reset clears any previous output and marks the run as completed,
    # so no new summarization is triggered on this rerun.
    COMPLETED = True
    tmp_out.empty()

bar = st.progress(0)

if START and not COMPLETED:
	start_time = perf_counter()
	
	with tmp_out.container():
		st.write("Loading in models and preparing article...")
	
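	# Load the selected model and its tokenizer; model_max_length is the model's input token budget.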
	summarization_model, summarization_tokenizer = load_summarizer(model_name)
	summarizer_token_limit = summarization_tokenizer.model_max_length

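	# Rough input-size estimate: Pegasus inputs are budgeted in sentences, the other
	# models in words. Both are proxies for the tokenizer's true token count.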
	if "pegasus" in model_name.lower():
		input_toks = sent_tokenize(article)
		input_sent_toks = input_toks
		input_word_toks = word_tokenize(article)
		num_toks = len(input_toks)
	else:
		input_toks = word_tokenize(article)
		input_word_toks = input_toks
		input_sent_toks = sent_tokenize(article)
		num_toks = len(input_toks)

	bar.progress(15)

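	# If the input already fits within the model's window, summarize it in a single pass.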
	if num_toks <= summarizer_token_limit:
		with tmp_out.container():
			st.write(f"Input token count ({num_toks}) <= token limit ({summarizer_token_limit}), skipping optimization ...")

		pred_summary = summarize_input(
			article,
			summarization_model,
			summarization_tokenizer,
			max_length=max_length,
			min_length=min_length,
		)
		end_time = perf_counter()
		time_taken = end_time - start_time
		bar.progress(100)

	else:
		with tmp_out.container():
			st.write(f"Input token count ({num_toks}) > token limit ({summarizer_token_limit}), optimizing ...")
			st.write(f"Going beyond the {model_name} token limit of {summarizer_token_limit} ...")

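		# Embed each sentence via the project's sentence-embedding utilities,
		# using mean pooling over token embeddings.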
		embeddings = make_embeddings(input_sent_toks, mean_pooling)
		embeddings = embeddings.numpy()

		bar.progress(30)

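		# Estimate how many chunks are needed so each one fits within the token budget.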
		n_clusters_estimate = math.ceil(num_toks / summarizer_token_limit)

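		# Cluster the sentence embeddings so related sentences end up in the same chunk.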
		clemb = ClusterEmbeddings(
			cluster_estimate=n_clusters_estimate,
			cluster_fn="agglo", # much better
			embeddings=embeddings,
			sentences=np.array(input_sent_toks),
			words=np.array(input_word_toks)
		)

		bar.progress(50)
		curr = 50
		rem = 90 - curr

		sentence_clusters = clemb.get_sentence_clusters()	

		n = len(sentence_clusters)
		summs = ""
		for cluster in sentence_clusters:
			cluster_summary = summarize_input(
				cluster, 
				summarization_model, 
				summarization_tokenizer,
				max_length=250,
				min_length=50,
			)
			if isinstance(cluster_summary, list):
				cluster_summary = cluster_summary[0]
			summs += cluster_summary + " "

			# Advance the bar by an equal share of the remaining range per cluster.
			curr += rem / n
			bar.progress(int(curr))
			
		bar.progress(90)

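		# Final pass: compress the concatenated cluster summaries down to the
		# user's chosen min/max summary lengths.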
		pred_summary = summarize_input(
			summs, 
			summarization_model, 
			summarization_tokenizer,
			max_length=max_length,
			min_length=min_length,
		)

		bar.progress(100)

		end_time = perf_counter()
		time_taken = end_time - start_time

	with tmp_out.container():
		st.write(f"Took {time_taken} seconds")
		st.write(f"Summary: {pred_summary}")
	
	START = False
	COMPLETED = True
