stephenleo commited on
Commit
d9f2adf
·
1 Parent(s): 576be81

many optimizations for streamlit

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +42 -17
  3. helpers.py +79 -52
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py CHANGED
@@ -2,12 +2,22 @@ import networkx as nx
2
  from streamlit.components.v1 import html
3
  import streamlit as st
4
  import helpers
 
 
 
5
  st.set_page_config(layout='wide',
6
  page_title='STriP: Semantic Similarity of Scientific Papers!',
7
  page_icon='💡'
8
  )
9
 
10
 
 
 
 
 
 
 
 
11
  def main():
12
  st.title('STriP (S3P): Semantic Similarity of Scientific Papers!')
13
 
@@ -18,39 +28,51 @@ def main():
18
  ##########
19
  # Load data
20
  ##########
 
21
  if uploaded_file is not None:
22
  df = helpers.load_data(uploaded_file)
23
  else:
24
  df = helpers.load_data('data.csv')
25
 
26
  data = df.copy()
 
 
 
 
 
27
  st.write(f'Number of papers: {len(data)}')
28
  st.write('First 5 rows of loaded data:')
29
- st.write(data[['Title', 'Abstract']].head())
 
 
 
 
 
 
30
 
31
- if data is not None:
32
  ##########
33
  # Topic modeling
34
  ##########
 
35
  st.header('🔥 Topic Modeling')
36
 
37
  cols = st.columns(3)
38
  with cols[0]:
39
  min_topic_size = st.slider('Minimum topic size', key='min_topic_size', min_value=2,
40
- max_value=int(len(data)/3), step=1, value=3,
41
  help='The minimum size of the topic. Increasing this value will lead to a lower number of clusters/topics.')
42
  with cols[1]:
43
  n_gram_range = st.slider('N-gram range', key='n_gram_range', min_value=1,
44
- max_value=4, step=1, value=(1, 3),
45
  help='N-gram range for the topic model')
46
  with cols[2]:
47
  st.text('')
48
  st.text('')
49
  st.button('Reset Defaults', on_click=helpers.reset_default_topic_sliders, key='reset_topic_sliders',
50
- kwargs={'min_topic_size': 3, 'n_gram_range': (1, 3)})
51
 
52
  with st.spinner('Topic Modeling'):
53
- data, topic_model, topics = helpers.topic_modeling(
54
  data, min_topic_size=min_topic_size, n_gram_range=n_gram_range)
55
 
56
  mapping = {
@@ -65,7 +87,7 @@ def main():
65
  topic_model_vis_option = st.selectbox(
66
  'Select Topic Modeling Visualization', mapping.keys())
67
  try:
68
- fig = mapping[topic_model_vis_option]()
69
  fig.update_layout(title='')
70
  st.plotly_chart(fig, use_container_width=True)
71
  except:
@@ -75,18 +97,18 @@ def main():
75
  ##########
76
  # STriP Network
77
  ##########
 
78
  st.header('🚀 STriP Network')
79
 
80
- with st.spinner('Embedding generation'):
81
- data = helpers.embeddings(data)
82
-
83
  with st.spinner('Cosine Similarity Calculation'):
84
  cosine_sim_matrix = helpers.cosine_sim(data)
85
 
86
- min_value, value = helpers.calc_optimal_threshold(
87
  cosine_sim_matrix,
88
  # 25% is a good value for the number of papers
89
- max_connections=helpers.calc_max_connections(len(data), 0.25)
 
 
90
  )
91
 
92
  cols = st.columns(3)
@@ -107,7 +129,7 @@ def main():
107
 
108
  with st.spinner('Network Generation'):
109
  nx_net, pyvis_net = helpers.network_plot(
110
- data, topics, neighbors)
111
 
112
  # Save and read graph as HTML file (on Streamlit Sharing)
113
  try:
@@ -129,6 +151,7 @@ def main():
129
  ##########
130
  # Centrality
131
  ##########
 
132
  st.header('🏅 Most Important Papers')
133
 
134
  centrality_mapping = {
@@ -146,10 +169,12 @@ def main():
146
  # Calculate centrality
147
  centrality = centrality_mapping[centrality_option](nx_net)
148
 
149
- with st.spinner('Network Centrality Calculation'):
150
- fig = helpers.network_centrality(
151
- data, centrality, centrality_option)
152
- st.plotly_chart(fig, use_container_width=True)
 
 
153
 
154
  st.markdown(
155
  """
 
2
  from streamlit.components.v1 import html
3
  import streamlit as st
4
  import helpers
5
+ import logging
6
+
7
+ # Setup Basic Configuration
8
  st.set_page_config(layout='wide',
9
  page_title='STriP: Semantic Similarity of Scientific Papers!',
10
  page_icon='💡'
11
  )
12
 
13
 
14
+ logging.basicConfig(level=logging.INFO,
15
+ format='%(asctime)s %(levelname)s: %(message)s',
16
+ datefmt='%Y-%m-%d %H:%M:%S')
17
+
18
+ logger = logging.getLogger('main')
19
+
20
+
21
  def main():
22
  st.title('STriP (S3P): Semantic Similarity of Scientific Papers!')
23
 
 
28
  ##########
29
  # Load data
30
  ##########
31
+ logger.info('========== Step1: Loading data ==========')
32
  if uploaded_file is not None:
33
  df = helpers.load_data(uploaded_file)
34
  else:
35
  df = helpers.load_data('data.csv')
36
 
37
  data = df.copy()
38
+ selected_cols = st.multiselect('Select columns to analyse', options=data.columns,
39
+ default=[col for col in data.columns if col.lower() in ['title', 'abstract']])
40
+ data = data[selected_cols]
41
+ data = data.dropna()
42
+ data = data.reset_index(drop=True)
43
  st.write(f'Number of papers: {len(data)}')
44
  st.write('First 5 rows of loaded data:')
45
+ st.write(data[selected_cols].head())
46
+
47
+ if (data is not None) and selected_cols:
48
+ # For 'allenai-specter'
49
+ data['Text'] = data[data.columns[0]]
50
+ for column in data.columns[1:]:
51
+ data['Text'] = data['Text'] + '[SEP]' + data[column].astype(str)
52
 
 
53
  ##########
54
  # Topic modeling
55
  ##########
56
+ logger.info('========== Step2: Topic modeling ==========')
57
  st.header('🔥 Topic Modeling')
58
 
59
  cols = st.columns(3)
60
  with cols[0]:
61
  min_topic_size = st.slider('Minimum topic size', key='min_topic_size', min_value=2,
62
+ max_value=round(len(data)*0.25), step=1, value=min(round(len(data)/25), 10),
63
  help='The minimum size of the topic. Increasing this value will lead to a lower number of clusters/topics.')
64
  with cols[1]:
65
  n_gram_range = st.slider('N-gram range', key='n_gram_range', min_value=1,
66
+ max_value=3, step=1, value=(1, 2),
67
  help='N-gram range for the topic model')
68
  with cols[2]:
69
  st.text('')
70
  st.text('')
71
  st.button('Reset Defaults', on_click=helpers.reset_default_topic_sliders, key='reset_topic_sliders',
72
+ kwargs={'min_topic_size': min(round(len(data)/25), 10), 'n_gram_range': (1, 2)})
73
 
74
  with st.spinner('Topic Modeling'):
75
+ topic_data, topic_model, topics = helpers.topic_modeling(
76
  data, min_topic_size=min_topic_size, n_gram_range=n_gram_range)
77
 
78
  mapping = {
 
87
  topic_model_vis_option = st.selectbox(
88
  'Select Topic Modeling Visualization', mapping.keys())
89
  try:
90
+ fig = mapping[topic_model_vis_option](top_n_topics=10)
91
  fig.update_layout(title='')
92
  st.plotly_chart(fig, use_container_width=True)
93
  except:
 
97
  ##########
98
  # STriP Network
99
  ##########
100
+ logger.info('========== Step3: STriP Network ==========')
101
  st.header('🚀 STriP Network')
102
 
 
 
 
103
  with st.spinner('Cosine Similarity Calculation'):
104
  cosine_sim_matrix = helpers.cosine_sim(data)
105
 
106
+ value, min_value = helpers.calc_optimal_threshold(
107
  cosine_sim_matrix,
108
  # 25% is a good value for the number of papers
109
+ max_connections=min(
110
+ helpers.calc_max_connections(len(data), 0.25), 5_000
111
+ )
112
  )
113
 
114
  cols = st.columns(3)
 
129
 
130
  with st.spinner('Network Generation'):
131
  nx_net, pyvis_net = helpers.network_plot(
132
+ topic_data, topics, neighbors)
133
 
134
  # Save and read graph as HTML file (on Streamlit Sharing)
135
  try:
 
151
  ##########
152
  # Centrality
153
  ##########
154
+ logger.info('========== Step4: Network Centrality ==========')
155
  st.header('🏅 Most Important Papers')
156
 
157
  centrality_mapping = {
 
169
  # Calculate centrality
170
  centrality = centrality_mapping[centrality_option](nx_net)
171
 
172
+ cols = st.columns([1, 10, 1])
173
+ with cols[1]:
174
+ with st.spinner('Network Centrality Calculation'):
175
+ fig = helpers.network_centrality(
176
+ topic_data, centrality, centrality_option)
177
+ st.plotly_chart(fig, use_container_width=True)
178
 
179
  st.markdown(
180
  """
helpers.py CHANGED
@@ -8,6 +8,10 @@ from sklearn.feature_extraction.text import CountVectorizer
8
  import pandas as pd
9
  import numpy as np
10
  import networkx as nx
 
 
 
 
11
 
12
 
13
  def reset_default_topic_sliders(min_topic_size, n_gram_range):
@@ -19,61 +23,60 @@ def reset_default_threshold_slider(threshold):
19
  st.session_state['threshold'] = threshold
20
 
21
 
22
- @st.cache(allow_output_mutation=True)
23
- def load_sbert_model():
24
- return SentenceTransformer('allenai-specter')
25
-
26
-
27
  @st.cache()
28
  def load_data(uploaded_file):
29
  data = pd.read_csv(uploaded_file)
30
 
31
- data = data[['Title', 'Abstract']]
32
- data = data.dropna()
33
- data = data.reset_index(drop=True)
34
-
35
  return data
36
 
37
 
38
- @st.cache(allow_output_mutation=True)
39
- def topic_modeling(data, min_topic_size, n_gram_range):
40
- """Topic modeling using BERTopic
41
- """
42
- topic_model = BERTopic(
43
- embedding_model=load_sbert_model(),
 
 
 
 
44
  vectorizer_model=CountVectorizer(
45
- stop_words='english', ngram_range=n_gram_range),
46
- min_topic_size=min_topic_size
 
 
47
  )
48
 
49
- # For 'allenai-specter'
50
- data['Title + Abstract'] = data['Title'] + '[SEP]' + data['Abstract']
 
 
 
 
 
51
 
52
  # Train the topic model
53
- data["Topic"], data["Probs"] = topic_model.fit_transform(
54
- data['Title + Abstract'])
 
55
 
56
  # Merge topic results
57
- topic_df = topic_model.get_topic_info()[['Topic', 'Name']]
58
- data = data.merge(topic_df, on='Topic', how='left')
 
 
59
 
60
  # Topics
61
- topics = topic_df.set_index('Topic').to_dict(orient='index')
62
-
63
- return data, topic_model, topics
64
 
65
-
66
- @st.cache(allow_output_mutation=True)
67
- def embeddings(data):
68
- data['embedding'] = load_sbert_model().encode(
69
- data['Title + Abstract']).tolist()
70
-
71
- return data
72
 
73
 
74
  @st.cache()
75
  def cosine_sim(data):
76
- cosine_sim_matrix = cosine_similarity(data['embedding'].values.tolist())
 
77
 
78
  # Take only upper triangular matrix
79
  cosine_sim_matrix = np.triu(cosine_sim_matrix, k=1)
@@ -93,10 +96,11 @@ def calc_optimal_threshold(cosine_sim_matrix, max_connections):
93
  """Calculates the optimal threshold for the cosine similarity matrix.
94
  Allows a max of max_connections
95
  """
96
- thresh_sweep = np.arange(0.05, 1.05, 0.05)
 
97
  for idx, threshold in enumerate(thresh_sweep):
98
  neighbors = np.argwhere(cosine_sim_matrix >= threshold).tolist()
99
- if len(neighbors) < max_connections:
100
  break
101
 
102
  return round(thresh_sweep[idx-1], 2).item(), round(thresh_sweep[idx], 2).item()
@@ -104,6 +108,7 @@ def calc_optimal_threshold(cosine_sim_matrix, max_connections):
104
 
105
  @st.cache()
106
  def calc_neighbors(cosine_sim_matrix, threshold):
 
107
  neighbors = np.argwhere(cosine_sim_matrix >= threshold).tolist()
108
 
109
  return neighbors, len(neighbors)
@@ -122,9 +127,10 @@ def pyvis_hash_func(pyvis_net):
122
 
123
 
124
  @st.cache(hash_funcs={nx.Graph: nx_hash_func, Network: pyvis_hash_func})
125
- def network_plot(data, topics, neighbors):
126
  """Creates a network plot of connected papers. Colored by Topic Model topics.
127
  """
 
128
  nx_net = nx.Graph()
129
  pyvis_net = Network(height='750px', width='100%', bgcolor='#222222')
130
 
@@ -135,14 +141,21 @@ def network_plot(data, topics, neighbors):
135
  {
136
  'group': row.Topic,
137
  'label': row.Index,
138
- 'title': row.Title,
139
  'size': 20, 'font': {'size': 20, 'color': 'white'}
140
  }
141
  )
142
- for row in data.itertuples()
143
  ]
144
  nx_net.add_nodes_from(nodes)
145
- assert(nx_net.number_of_nodes() == len(data))
 
 
 
 
 
 
 
146
 
147
  # Add Legend Nodes
148
  step = 150
@@ -150,9 +163,9 @@ def network_plot(data, topics, neighbors):
150
  y = -500
151
  legend_nodes = [
152
  (
153
- len(data)+idx,
154
  {
155
- 'group': key, 'label': ', '.join(value['Name'].split('_')[1:]),
156
  'size': 30, 'physics': False, 'x': x, 'y': f'{y + idx*step}px',
157
  # , 'fixed': True,
158
  'shape': 'box', 'widthConstraint': 1000, 'font': {'size': 40, 'color': 'black'}
@@ -162,33 +175,47 @@ def network_plot(data, topics, neighbors):
162
  ]
163
  nx_net.add_nodes_from(legend_nodes)
164
 
165
- # Add Edges
166
- nx_net.add_edges_from(neighbors)
167
- assert(nx_net.number_of_edges() == len(neighbors))
168
-
169
  # Plot the Pyvis graph
170
  pyvis_net.from_nx(nx_net)
171
 
172
  return nx_net, pyvis_net
173
 
174
 
 
 
 
 
 
 
 
 
175
  @st.cache()
176
- def network_centrality(data, centrality, centrality_option):
177
  """Calculates the centrality of the network
178
  """
 
179
  # Sort Top 10 Central nodes
180
  central_nodes = sorted(
181
  centrality.items(), key=lambda item: item[1], reverse=True)
182
  central_nodes = pd.DataFrame(central_nodes, columns=[
183
  'node', centrality_option]).set_index('node')
184
 
185
- joined_data = data.join(central_nodes)
 
186
  top_central_nodes = joined_data.sort_values(
187
  centrality_option, ascending=False).head(10)
188
 
 
 
 
 
 
 
 
 
189
  # Plot the Top 10 Central nodes
190
- fig = px.bar(top_central_nodes, x=centrality_option, y='Title')
191
- fig.update_layout(yaxis={'categoryorder': 'total ascending'},
192
- font={'size': 15},
193
- height=800, width=800)
194
  return fig
 
8
  import pandas as pd
9
  import numpy as np
10
  import networkx as nx
11
+ import textwrap
12
+ import logging
13
+
14
+ logger = logging.getLogger('main')
15
 
16
 
17
  def reset_default_topic_sliders(min_topic_size, n_gram_range):
 
23
  st.session_state['threshold'] = threshold
24
 
25
 
 
 
 
 
 
26
  @st.cache()
27
  def load_data(uploaded_file):
28
  data = pd.read_csv(uploaded_file)
29
 
 
 
 
 
30
  return data
31
 
32
 
33
+ @st.cache()
34
+ def embedding_gen(data):
35
+ logger.info('Calculating Embeddings')
36
+ return SentenceTransformer('allenai-specter').encode(data['Text'])
37
+
38
+
39
+ @st.cache()
40
+ def load_bertopic_model(min_topic_size, n_gram_range):
41
+ logger.info('Loading BERTopic model')
42
+ return BERTopic(
43
  vectorizer_model=CountVectorizer(
44
+ stop_words='english', ngram_range=n_gram_range
45
+ ),
46
+ min_topic_size=min_topic_size,
47
+ verbose=True
48
  )
49
 
50
+
51
+ @st.cache()
52
+ def topic_modeling(data, min_topic_size, n_gram_range):
53
+ """Topic modeling using BERTopic
54
+ """
55
+ logger.info('Calculating Topic Model')
56
+ topic_model = load_bertopic_model(min_topic_size, n_gram_range)
57
 
58
  # Train the topic model
59
+ topic_data = data.copy()
60
+ topic_data["Topic"], topic_data["Probs"] = topic_model.fit_transform(
61
+ data['Text'], embeddings=embedding_gen(data))
62
 
63
  # Merge topic results
64
+ topic_df = topic_model.get_topic_info()
65
+ topic_df.columns = ['Topic', 'Topic_Count', 'Topic_Name']
66
+ topic_df = topic_df.sort_values(by='Topic_Count', ascending=False)
67
+ topic_data = topic_data.merge(topic_df, on='Topic', how='left')
68
 
69
  # Topics
70
+ # Optimization: Only take top 10 largest topics
71
+ topics = topic_df.head(10).set_index('Topic').to_dict(orient='index')
 
72
 
73
+ return topic_data, topic_model, topics
 
 
 
 
 
 
74
 
75
 
76
  @st.cache()
77
  def cosine_sim(data):
78
+ logger.info('Cosine similarity')
79
+ cosine_sim_matrix = cosine_similarity(embedding_gen(data))
80
 
81
  # Take only upper triangular matrix
82
  cosine_sim_matrix = np.triu(cosine_sim_matrix, k=1)
 
96
  """Calculates the optimal threshold for the cosine similarity matrix.
97
  Allows a max of max_connections
98
  """
99
+ logger.info('Calculating optimal threshold')
100
+ thresh_sweep = np.arange(0.05, 1.05, 0.05)[::-1]
101
  for idx, threshold in enumerate(thresh_sweep):
102
  neighbors = np.argwhere(cosine_sim_matrix >= threshold).tolist()
103
+ if len(neighbors) > max_connections:
104
  break
105
 
106
  return round(thresh_sweep[idx-1], 2).item(), round(thresh_sweep[idx], 2).item()
 
108
 
109
  @st.cache()
110
  def calc_neighbors(cosine_sim_matrix, threshold):
111
+ logger.info('Calculating neighbors')
112
  neighbors = np.argwhere(cosine_sim_matrix >= threshold).tolist()
113
 
114
  return neighbors, len(neighbors)
 
127
 
128
 
129
  @st.cache(hash_funcs={nx.Graph: nx_hash_func, Network: pyvis_hash_func})
130
+ def network_plot(topic_data, topics, neighbors):
131
  """Creates a network plot of connected papers. Colored by Topic Model topics.
132
  """
133
+ logger.info('Calculating Network Plot')
134
  nx_net = nx.Graph()
135
  pyvis_net = Network(height='750px', width='100%', bgcolor='#222222')
136
 
 
141
  {
142
  'group': row.Topic,
143
  'label': row.Index,
144
+ 'title': row.Text,
145
  'size': 20, 'font': {'size': 20, 'color': 'white'}
146
  }
147
  )
148
+ for row in topic_data.itertuples()
149
  ]
150
  nx_net.add_nodes_from(nodes)
151
+ assert(nx_net.number_of_nodes() == len(topic_data))
152
+
153
+ # Add Edges
154
+ nx_net.add_edges_from(neighbors)
155
+ assert(nx_net.number_of_edges() == len(neighbors))
156
+
157
+ # Optimization: Remove Isolated nodes
158
+ nx_net.remove_nodes_from(list(nx.isolates(nx_net)))
159
 
160
  # Add Legend Nodes
161
  step = 150
 
163
  y = -500
164
  legend_nodes = [
165
  (
166
+ len(topic_data)+idx,
167
  {
168
+ 'group': key, 'label': ', '.join(value['Topic_Name'].split('_')[1:]),
169
  'size': 30, 'physics': False, 'x': x, 'y': f'{y + idx*step}px',
170
  # , 'fixed': True,
171
  'shape': 'box', 'widthConstraint': 1000, 'font': {'size': 40, 'color': 'black'}
 
175
  ]
176
  nx_net.add_nodes_from(legend_nodes)
177
 
 
 
 
 
178
  # Plot the Pyvis graph
179
  pyvis_net.from_nx(nx_net)
180
 
181
  return nx_net, pyvis_net
182
 
183
 
184
+ def text_processing(text):
185
+ text = text.split('[SEP]')
186
+ text = '<br><br>'.join(text)
187
+ text = '<br>'.join(textwrap.wrap(text, width=50))[:500]
188
+ text = text + '...'
189
+ return text
190
+
191
+
192
  @st.cache()
193
+ def network_centrality(topic_data, centrality, centrality_option):
194
  """Calculates the centrality of the network
195
  """
196
+ logger.info('Calculating Network Centrality')
197
  # Sort Top 10 Central nodes
198
  central_nodes = sorted(
199
  centrality.items(), key=lambda item: item[1], reverse=True)
200
  central_nodes = pd.DataFrame(central_nodes, columns=[
201
  'node', centrality_option]).set_index('node')
202
 
203
+ joined_data = topic_data.join(central_nodes)
204
+
205
  top_central_nodes = joined_data.sort_values(
206
  centrality_option, ascending=False).head(10)
207
 
208
+ # Prepare for plot
209
+ top_central_nodes = top_central_nodes.reset_index()
210
+ top_central_nodes['index'] = top_central_nodes['index'].astype(str)
211
+ top_central_nodes['Topic_Name'] = top_central_nodes['Topic_Name'].apply(
212
+ lambda x: ', '.join(x.split('_')[1:]))
213
+ top_central_nodes['Text'] = top_central_nodes['Text'].apply(
214
+ text_processing)
215
+
216
  # Plot the Top 10 Central nodes
217
+ fig = px.bar(top_central_nodes, x=centrality_option, y='index',
218
+ color='Topic_Name', hover_data=['Text'], orientation='h')
219
+ fig.update_layout(yaxis={'categoryorder': 'total ascending', 'visible': False, 'showticklabels': False},
220
+ font={'size': 15}, height=800)
221
  return fig