Spaces: Runtime error

jaisidhsingh committed · Commit f98f59d · 1 Parent(s): c548571

add code

Files changed:
- app.py.py +162 -0
- models/__pycache__/summarizers.cpython-39.pyc +0 -0
- models/summarizers.py +56 -0
- requirements.txt +0 -0
- summarize.py +82 -0
- utils/clustering.py +64 -0
- utils/sentence_embedding.py +44 -0
app.py.py ADDED
@@ -0,0 +1,162 @@
import streamlit as st
from summarize import *
from utils.sentence_embedding import *
from utils.clustering import *
from models.summarizers import *
from nltk.tokenize import sent_tokenize, word_tokenize
import math
from time import perf_counter


# flags gating the Summarize/Reset buttons
START = False
COMPLETED = False
PLACEHOLDER = "Enter your article"

st.markdown("Extractive Summarization for Large Articles 😊")

article = st.text_input(
    label="Welcome, enter your article, press enter, and then Summarize",
    value=PLACEHOLDER,
)

model_name = st.sidebar.selectbox(
    label="Pick your model of choice:",
    options=("BART", "Pegasus", "Distill-BART", "RoBERTa")
)

max_length = st.sidebar.slider(
    label="Choose the maximum length of the summary",
    min_value=100,
    max_value=500,
    value=250
)

min_length = st.sidebar.slider(
    label="Choose the minimum length of the summary",
    min_value=20,
    max_value=150,
    value=50
)

go = st.button(
    label="Summarize",
    key=0,
)

reset = st.button(
    label="Reset",
    key=1,
)


START = go
tmp_out = st.empty()

if reset:
    COMPLETED = not reset
    tmp_out.empty()
else:
    COMPLETED = reset


bar = st.progress(0)

if START and not COMPLETED:
    start_time = perf_counter()

    with tmp_out.container():
        st.write("Loading in models and preparing article...")

    summarization_model, summarization_tokenizer = load_summarizer(model_name)
    summarizer_token_limit = summarization_tokenizer.model_max_length

    # count sentences for Pegasus, words for the other models
    if "pegasus" in model_name.lower():
        input_toks = sent_tokenize(article)
        input_sent_toks = input_toks
        input_word_toks = word_tokenize(article)
        num_toks = len(input_toks)
    else:
        input_toks = word_tokenize(article)
        input_word_toks = input_toks
        input_sent_toks = sent_tokenize(article)
        num_toks = len(input_toks)

    bar.progress(15)

    if num_toks <= summarizer_token_limit:
        # the article fits in the model's context window: summarize directly
        with tmp_out.container():
            st.write(f"Input token count ({num_toks}) <= token limit ({summarizer_token_limit}), skipping optimization ...")

        pred_summary = summarize_input(article, summarization_model, summarization_tokenizer)
        end_time = perf_counter()
        time_taken = end_time - start_time
        bar.progress(100)

    else:
        # the article is too long: cluster its sentences, summarize each
        # cluster, then summarize the concatenated cluster summaries
        with tmp_out.container():
            st.write(f"Input token count ({num_toks}) > token limit ({summarizer_token_limit}), optimizing ...")
            st.write(f"Going beyond the {model_name} token limit:", summarizer_token_limit)

        input_sent_toks = sent_tokenize(article)
        embeddings = make_embeddings(input_sent_toks, mean_pooling)
        embeddings = embeddings.numpy()

        bar.progress(30)

        n_clusters_estimate = math.ceil(num_toks / summarizer_token_limit)

        clemb = ClusterEmbeddings(
            cluster_estimate=n_clusters_estimate,
            cluster_fn="agglo",  # works much better than k-means here
            embeddings=embeddings,
            sentences=np.array(input_sent_toks),
            words=np.array(input_word_toks)
        )

        bar.progress(50)
        curr = 50
        rem = 90 - curr

        sentence_clusters = clemb.get_sentence_clusters()

        n = len(sentence_clusters)
        summs = ""
        for cluster in sentence_clusters:
            cluster_summary = summarize_input(
                cluster,
                summarization_model,
                summarization_tokenizer,
                max_length=250,
                min_length=50,
            )
            if isinstance(cluster_summary, list):
                cluster_summary = cluster_summary[0]
            summs += cluster_summary + " "

            # advance the bar by an equal share of the remaining range per cluster
            curr += rem / n
            bar.progress(int(curr))

        bar.progress(90)

        pred_summary = summarize_input(
            summs,
            summarization_model,
            summarization_tokenizer,
            max_length=max_length,
            min_length=min_length,
        )

        bar.progress(100)

        end_time = perf_counter()
        time_taken = end_time - start_time

    with tmp_out.container():
        st.write(f"Took {time_taken} seconds")
        st.write(f"Summary: {pred_summary}")

    START = False
    COMPLETED = True

else:
    pass
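Note: app.py.py depends on nltk's punkt tokenizer data, which nothing visible in this commit downloads; without it, sent_tokenize raises a LookupError at runtime. A one-time setup sketch (the download call is standard nltk, not part of the commit):

import nltk
nltk.download("punkt")  # tokenizer data used by sent_tokenize and word_tokenize

The app itself is launched with `streamlit run app.py.py`.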
models/__pycache__/summarizers.cpython-39.pyc ADDED
Binary file (1.58 kB)
models/summarizers.py ADDED
@@ -0,0 +1,56 @@
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


def load_summarizer(model_code):
    # map the UI model names to Hugging Face Hub checkpoints
    name_dict = {
        "bart": "facebook/bart-large-cnn",
        "distill-bart": "sshleifer/distilbart-cnn-12-6",
        "roberta": "google/roberta2roberta_L-24_cnn_daily_mail",
        "pegasus": "google/pegasus-cnn_dailymail"
    }

    model_name = name_dict[model_code.lower()]
    model, tokenizer = None, None

    if "bart" in model_name:  # matches both bart-large-cnn and distilbart
        tokenizer = BartTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name)

    if "pegasus" in model_name:
        tokenizer = PegasusTokenizer.from_pretrained(model_name)
        model = PegasusForConditionalGeneration.from_pretrained(model_name)

    if "roberta" in model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    return model, tokenizer


def summarize_input(
    input_article,
    model,
    tokenizer,
    max_length=150,
    min_length=50,
    num_beams=3,
    length_penalty=0.5,
    no_repeat_ngram_size=3
):
    # encode the article, truncating it to the model's context window
    text_input_ids = tokenizer.batch_encode_plus(
        [input_article],
        return_tensors='pt',
        truncation=True,
        max_length=tokenizer.model_max_length
    )['input_ids'].to("cpu")

    summary_ids = model.generate(
        text_input_ids,
        num_beams=int(num_beams),
        length_penalty=float(length_penalty),
        max_length=int(max_length),
        min_length=int(min_length),
        no_repeat_ngram_size=int(no_repeat_ngram_size)
    )
    summary_txt = tokenizer.decode(summary_ids.squeeze(), skip_special_tokens=True)
    return summary_txt.replace("<n>", "")  # Pegasus uses <n> as a sentence separator
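A minimal usage sketch for the two helpers above, assuming the imports resolve from the repo root; the article string is a placeholder, and the model key may be any name_dict key, case-insensitive:

from models.summarizers import load_summarizer, summarize_input

model, tokenizer = load_summarizer("BART")
summary = summarize_input(
    "Some long news article text ...",  # placeholder input
    model,
    tokenizer,
    max_length=150,
    min_length=50,
)
print(summary)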
requirements.txt ADDED
Binary file (9.19 kB)
summarize.py ADDED
@@ -0,0 +1,82 @@
from utils.sentence_embedding import *
from utils.clustering import *
from models.summarizers import *
from nltk.tokenize import sent_tokenize, word_tokenize
import math
from time import perf_counter


def get_summary(model_name, article, max_length, min_length, increment):
    # `increment` is a single-element list used as a mutable progress cell:
    # the caller can poll increment[0] (0-100) while this function runs
    start_time = perf_counter()
    summarization_model, summarization_tokenizer = load_summarizer(model_name)
    summarizer_token_limit = summarization_tokenizer.model_max_length
    print("Going beyond token limit:", summarizer_token_limit)

    input_word_toks = word_tokenize(article)
    num_words = len(input_word_toks)

    if num_words <= summarizer_token_limit and model_name == "t5":
        pred_summary = summarize_input(article, summarization_model, summarization_tokenizer)
        end_time = perf_counter()
        print("Time taken: ", end_time - start_time)
    else:
        input_sent_toks = sent_tokenize(article)
        embeddings = make_embeddings(input_sent_toks, mean_pooling)
        embeddings = embeddings.numpy()

        increment[0] = 20

        n_clusters_estimate = math.ceil(num_words / summarizer_token_limit)

        clemb = ClusterEmbeddings(
            cluster_estimate=n_clusters_estimate,
            cluster_fn="agglo",  # works much better than k-means here
            embeddings=embeddings,
            sentences=np.array(input_sent_toks),
            words=np.array(input_word_toks)
        )

        increment[0] = 50

        sentence_clusters = clemb.get_sentence_clusters()

        n = len(sentence_clusters)
        summs = ""
        for cluster in sentence_clusters:
            cluster_summary = summarize_input(
                cluster,
                summarization_model,
                summarization_tokenizer,
                max_length=250,
                min_length=50,
            )
            if isinstance(cluster_summary, list):
                cluster_summary = cluster_summary[0]
            summs += cluster_summary + " "

            increment[0] += 40 / n  # 50 -> 90 across the clusters

        pred_summary = summarize_input(
            summs,
            summarization_model,
            summarization_tokenizer,
            max_length=max_length,
            min_length=min_length,
        )

        increment[0] = 100

    end_time = perf_counter()
    time_taken = end_time - start_time

    return pred_summary, time_taken


def test():
    article = """Recent text-to-image matching models apply contrastive learning to large corpora of uncurated pairs of images and sentences. While such models can provide a powerful score for matching and subsequent zero-shot tasks, they are not capable of generating a caption given an image. In this work, we repurpose such models to generate a descriptive text given an image at inference time, without any further training or tuning step. This is done by combining the visual-semantic model with a large language model, benefiting from the knowledge in both web-scale models. The resulting captions are much less restrictive than those obtained by supervised captioning methods. Moreover, as a zero-shot learning method, it is extremely flexible and we demonstrate its ability to perform image arithmetic in which the inputs can be either images or text and the output is a sentence."""
    model_name = "BART"
    summ, time_taken = get_summary(model_name, article, 250, 150, increment=[0])
    print(summ)
    print(time_taken)
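get_summary reports progress through its increment argument: a single-element list acts as a mutable cell the caller can poll (for example from a UI thread) while the function runs. A usage sketch with placeholder text:

progress = [0]  # get_summary drives progress[0] from 0 to 100
summary, seconds = get_summary("BART", "Some long article text ...", 250, 50, progress)
print(f"{summary} (took {seconds:.1f}s, progress={progress[0]})")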
utils/clustering.py ADDED
@@ -0,0 +1,64 @@
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt


class ClusterEmbeddings():
    def __init__(
        self,
        cluster_estimate,
        cluster_fn,
        embeddings,
        sentences,
        words
    ):
        self.cluster_estimate = cluster_estimate
        self.embeddings = embeddings
        self.sentences = sentences
        self.words = words

        self.cluster_fn = cluster_fn
        if self.cluster_fn == "agglo":
            self.clustering_algo = AgglomerativeClustering(n_clusters=self.cluster_estimate)
            self.num_clusters = cluster_estimate

        elif self.cluster_fn == "kmeans":
            self.clustering_algo = KMeans(n_clusters=self.cluster_estimate)
            self.num_clusters = cluster_estimate

        self.cluster = self.clustering_algo.fit(embeddings)
        self.labels = self.cluster.labels_

    def get_sentence_clusters(self):
        # join the sentences assigned to each cluster label into one text chunk
        sent_clusters = []
        chunk = ""

        for lbl in range(self.num_clusters):
            single_cluster = self.sentences[self.labels == lbl]
            for sent in single_cluster:
                chunk += sent + " "
            sent_clusters.append(chunk)
            chunk = ""

        return np.array(sent_clusters)

    def make_plot(self):
        # project the embeddings to 2D with t-SNE and scatter-plot each cluster
        projector = TSNE(
            n_components=2,
            learning_rate="auto",
            init="random"
        )
        proj_embeddings = np.array(
            projector.fit_transform(self.embeddings)
        )

        for lbl in range(self.num_clusters):
            xs = proj_embeddings[self.labels == lbl]
            plt.scatter(xs[:, 0], xs[:, 1], label=f"Cluster {lbl}")

        plt.legend()
        plt.xlabel("x1")
        plt.ylabel("x2")
        plt.show()
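ClusterEmbeddings only needs an (n_sentences, dim) embedding matrix plus a parallel array of sentences; a self-contained sketch on random data (the 384-dim size and the sentence strings are arbitrary stand-ins):

import numpy as np
from utils.clustering import ClusterEmbeddings

rng = np.random.default_rng(0)
fake_embeddings = rng.standard_normal((12, 384)).astype(np.float32)
fake_sentences = np.array([f"Sentence number {i}." for i in range(12)])

clemb = ClusterEmbeddings(
    cluster_estimate=3,
    cluster_fn="agglo",
    embeddings=fake_embeddings,
    sentences=fake_sentences,
    words=np.array([]),  # stored but unused by get_sentence_clusters
)
print(clemb.get_sentence_clusters())  # three concatenated sentence chunks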
utils/sentence_embedding.py ADDED
@@ -0,0 +1,44 @@
import os
import sys

# make the parent directory importable so the `configs` package resolves
cwd = os.getcwd()
module2add = os.path.dirname(cwd)
sys.path.append(module2add)

# configs/model_config.py is not part of this commit; it must define
# cfg.sent_model_name and cfg.sent_model_seq_limit (see below)
from configs.model_config import cfg as model_configs

from transformers import AutoTokenizer, AutoModel
import torch


def mean_pooling(model_output, attention_mask):
    # average the token embeddings, ignoring padded positions
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


def make_embeddings(sentence_list, pool_fn):
    tokenizer = AutoTokenizer.from_pretrained(model_configs.sent_model_name)
    model = AutoModel.from_pretrained(model_configs.sent_model_name)

    encoded_input = tokenizer(
        sentence_list,
        padding=True,
        truncation=True,
        max_length=model_configs.sent_model_seq_limit,
        return_tensors='pt'
    )
    with torch.no_grad():
        embeddings = model(**encoded_input)

    attn_mask = encoded_input['attention_mask']
    sentence_embeddings = pool_fn(embeddings, attn_mask)
    return sentence_embeddings


def test_embedder():
    sentences = ['This framework generates embeddings for each input sentence',
                 'Sentences are passed as a list of strings.',
                 'The quick brown fox jumps over the lazy dog.']

    embeddings = make_embeddings(sentences, mean_pooling)
    print(embeddings.shape)
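The cfg import above comes from configs/model_config.py, which this commit does not include. A hypothetical stand-in showing only the two attributes the code reads; the checkpoint name and sequence limit below are assumed examples, not necessarily what the Space uses:

# hypothetical configs/model_config.py
from types import SimpleNamespace

cfg = SimpleNamespace(
    sent_model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed sentence-encoder checkpoint
    sent_model_seq_limit=128,  # assumed max token length per sentence
)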