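"""Streamlit app for cluster-based summarization of long articles.

Inputs within the selected model's token limit are summarized directly;
longer inputs are sentence-embedded, clustered, summarized per cluster,
and the cluster summaries are condensed into the final summary.
"""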
import math
from time import perf_counter

import streamlit as st
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

# punkt backs the NLTK tokenizers; wordnet is presumably used by the
# local modules imported below.
nltk.download("punkt")
nltk.download("wordnet")

from summarize import *
from utils.sentence_embedding import *
from utils.clustering import *
from models.summarizers import *
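# Run-state flags. Streamlit re-executes this script on every interaction,
# so these are re-initialized at the top of each run.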
START = False
COMPLETED = False
PLACEHOLDER = "Enter your article"
st.markdown("# Inference-Time Optimization for Large Article Summarization 😊")
article = st.text_input(
    label="Welcome, enter your article, press enter, and then Summarize",
    value=PLACEHOLDER,
)
model_name = st.sidebar.selectbox(
    label="Pick your model of choice:",
    options=("BART", "Pegasus", "Distill-BART", "RoBERTa"),
)
max_length = st.sidebar.slider(
    label="Choose the maximum length of the summary",
    min_value=100,
    max_value=500,
    value=250,
)
min_length = st.sidebar.slider(
    label="Choose the minimum length of the summary",
    min_value=20,
    max_value=150,
    value=50,
)
go = st.button(
    label="Summarize",
    key=0,
)
reset = st.button(
    label="Reset",
    key=1,
)
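# st.button returns True only on the rerun triggered by its click.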
START = go
tmp_out = st.empty()
# Reset clears the output area; in either case the run is marked incomplete.
if reset:
    tmp_out.empty()
COMPLETED = False
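# Pipeline progress bar on a 0-100 integer scale.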
bar = st.progress(0)
if START and not COMPLETED:
    start_time = perf_counter()
    with tmp_out.container():
        st.write("Loading in models and preparing article...")
    summarization_model, summarization_tokenizer = load_summarizer(model_name)
    summarizer_token_limit = summarization_tokenizer.model_max_length
    # Tokenize with NLTK: Pegasus inputs are counted in sentences, the
    # other models in words.
    if "pegasus" in model_name.lower():
        input_sent_toks = sent_tokenize(article)
        input_word_toks = word_tokenize(article)
        num_toks = len(input_sent_toks)
    else:
        input_word_toks = word_tokenize(article)
        input_sent_toks = sent_tokenize(article)
        num_toks = len(input_word_toks)
    bar.progress(15)
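    # num_toks counts NLTK tokens, not the model tokenizer's subword tokens,
    # so the limit check below is only an approximation.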
    if num_toks <= summarizer_token_limit:
        with tmp_out.container():
            st.write(f"Input token count ({num_toks}) <= token limit ({summarizer_token_limit}), skipping optimization ...")
        pred_summary = summarize_input(article, summarization_model, summarization_tokenizer)
        # summarize_input may return a single-element list (see the cluster loop below).
        if isinstance(pred_summary, list):
            pred_summary = pred_summary[0]
        end_time = perf_counter()
        time_taken = end_time - start_time
        bar.progress(100)
    else:
        with tmp_out.container():
            st.write(f"Input token count ({num_toks}) > token limit ({summarizer_token_limit}), optimizing ...")
            st.write(f"Going beyond the {model_name} token limit of {summarizer_token_limit} ...")
        # make_embeddings and mean_pooling presumably come from the
        # utils.sentence_embedding wildcard import above.
        embeddings = make_embeddings(input_sent_toks, mean_pooling)
        embeddings = embeddings.numpy()
        bar.progress(30)
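        # Heuristic: estimate one cluster per token-limit-sized chunk of the
        # input, e.g. 2400 tokens against a (typical BART) limit of 1024
        # gives ceil(2400 / 1024) = 3 clusters.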
        n_clusters_estimate = math.ceil(num_toks / summarizer_token_limit)
        clemb = ClusterEmbeddings(
            cluster_estimate=n_clusters_estimate,
            cluster_fn="agglo",  # agglomerative clustering works much better here
            embeddings=embeddings,
            sentences=np.array(input_sent_toks),
            words=np.array(input_word_toks),
        )
        bar.progress(50)
        curr = 50.0
        rem = 90 - curr
        sentence_clusters = clemb.get_sentence_clusters()
        n = len(sentence_clusters)
        inc = rem / n  # per-cluster progress increment
        summs = ""
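        # Map step: summarize each cluster independently with fixed length
        # bounds; the per-cluster summaries are fused below.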
        for cluster in sentence_clusters:
            cluster_summary = summarize_input(
                cluster,
                summarization_model,
                summarization_tokenizer,
                max_length=250,
                min_length=50,
            )
            # summarize_input may return a single-element list.
            if isinstance(cluster_summary, list):
                cluster_summary = cluster_summary[0]
            summs += cluster_summary + " "
            curr += inc
            bar.progress(int(curr))
        bar.progress(90)
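        # Reduce step: one final pass over the concatenated cluster summaries,
        # this time with the user-selected length bounds.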
        pred_summary = summarize_input(
            summs,
            summarization_model,
            summarization_tokenizer,
            max_length=max_length,
            min_length=min_length,
        )
        if isinstance(pred_summary, list):
            pred_summary = pred_summary[0]
        bar.progress(100)
        end_time = perf_counter()
        time_taken = end_time - start_time
    with tmp_out.container():
        st.write(f"Took {time_taken:.2f} seconds")
        st.write(f"Summary: {pred_summary}")
    START = False
    COMPLETED = True