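"""Streamlit demo: inference-time optimization for large-article summarization.

Pipeline (as implemented below): load the selected summarizer, tokenize the
article, and if the input fits the model's token limit summarize it directly;
otherwise embed each sentence, cluster the embeddings so each cluster roughly
fits the limit, summarize every cluster, and then summarize the concatenated
cluster summaries to produce the final output.
"""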
import math
from time import perf_counter

import numpy as np  # np.array is used below; previously only available via the wildcard imports
import streamlit as st
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

nltk.download("punkt")
nltk.download("wordnet")

from summarize import *
from utils.sentence_embedding import *
from utils.clustering import *
from models.summarizers import *

START = False
COMPLETED = False
PLACEHOLDER = "Enter your article"
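# These flags are module-level and recomputed from scratch on every rerun;
# Streamlit re-executes the whole script per interaction, so they do not
# persist across clicks (st.session_state would be needed for that).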
st.markdown("# Inference-Time Optimization for Large Article Summarization")
article = st.text_input(
    label="Welcome, enter your article, press enter, and then Summarize",
    value=PLACEHOLDER,
)
model_name = st.sidebar.selectbox(
    label="Pick your model of choice:",
    options=("BART", "Pegasus", "Distill-BART", "RoBERTa"),
)
max_length = st.sidebar.slider(
    label="Choose the maximum length of the summary",
    min_value=100,
    max_value=500,
    value=250,
)
min_length = st.sidebar.slider(
    label="Choose the minimum length of the summary",
    min_value=20,
    max_value=150,
    value=50,
)
go = st.button(
    label="Summarize",
    key=0,
)
reset = st.button(
    label="Reset",
    key=1,
)
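# st.button returns True only on the script run triggered by that click,
# so `go` and `reset` act as one-shot signals rather than persistent state.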
START = go
tmp_out = st.empty()
# Both branches originally left COMPLETED False; Reset's real job is
# clearing the output placeholder.
COMPLETED = False
if reset:
    tmp_out.empty()
bar = st.progress(0)
if START and not COMPLETED:
    start_time = perf_counter()
    with tmp_out.container():
        st.write("Loading in models and preparing article...")
    summarization_model, summarization_tokenizer = load_summarizer(model_name)
    summarizer_token_limit = summarization_tokenizer.model_max_length
    input_sent_toks = sent_tokenize(article)
    input_word_toks = word_tokenize(article)
    # Pegasus's budget is measured in sentences here; the other models are
    # measured in word tokens.
    if "pegasus" in model_name.lower():
        num_toks = len(input_sent_toks)
    else:
        num_toks = len(input_word_toks)
    bar.progress(15)
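    # NOTE: num_toks is an NLTK count, not the model tokenizer's subword
    # count, so the comparison below is an approximation; subword tokenization
    # usually produces more tokens than word_tokenize for the same text.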
    if num_toks <= summarizer_token_limit:
        with tmp_out.container():
            st.write(f"Input token count ({num_toks}) <= token limit ({summarizer_token_limit}), skipping optimization ...")
        pred_summary = summarize_input(
            article,
            summarization_model,
            summarization_tokenizer,
            max_length=max_length,  # honor the slider bounds on this path too
            min_length=min_length,
        )
        end_time = perf_counter()
        time_taken = end_time - start_time
        bar.progress(100)
    else:
        with tmp_out.container():
            st.write(f"Input token count ({num_toks}) > token limit ({summarizer_token_limit}), optimizing ...")
            st.write(f"Going beyond {model_name} token limit: {summarizer_token_limit}")
        # Embed every sentence (the article was already sentence-tokenized above).
        embeddings = make_embeddings(input_sent_toks, mean_pooling)
        embeddings = embeddings.numpy()
        bar.progress(30)
        n_clusters_estimate = math.ceil(num_toks / summarizer_token_limit)
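        # e.g. a 3,500-token article with a 1,024-token limit gives
        # ceil(3500 / 1024) = 4 clusters, each of which should roughly fit
        # within the summarizer's context window.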
        clemb = ClusterEmbeddings(
            cluster_estimate=n_clusters_estimate,
            cluster_fn="agglo",  # agglomerative clustering works much better here
            embeddings=embeddings,
            sentences=np.array(input_sent_toks),
            words=np.array(input_word_toks),
        )
        bar.progress(50)
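        # Map-reduce-style summarization: summarize each sentence cluster
        # independently, then concatenate the per-cluster summaries and
        # summarize the concatenation once more below.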
        curr = 50
        rem = 90 - curr
        sentence_clusters = clemb.get_sentence_clusters()
        n = len(sentence_clusters)
        inc = rem / n  # equal share of the remaining progress per cluster
        summs = ""
        for cluster in sentence_clusters:
            cluster_summary = summarize_input(
                cluster,
                summarization_model,
                summarization_tokenizer,
                max_length=250,
                min_length=50,
            )
            if isinstance(cluster_summary, list):
                cluster_summary = cluster_summary[0]
            summs += cluster_summary + " "
            curr += inc  # advance the bar toward 90 as clusters finish
            bar.progress(int(curr))
        bar.progress(90)
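        # Second pass: compress the concatenated cluster summaries down to
        # the user-requested length bounds from the sidebar sliders.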
        pred_summary = summarize_input(
            summs,
            summarization_model,
            summarization_tokenizer,
            max_length=max_length,
            min_length=min_length,
        )
        bar.progress(100)
        end_time = perf_counter()
        time_taken = end_time - start_time
    with tmp_out.container():
        st.write(f"Took {time_taken:.2f} seconds")
        st.write(f"Summary: {pred_summary}")
    START = False
    COMPLETED = True