DeDeckerThomas committed on
Commit
b96bd14
·
1 Parent(s): 31decce

Update HF space

Browse files
Files changed (3) hide show
  1. app.py +58 -89
  2. css/style.css +19 -0
  3. requirements.txt +0 -1
app.py CHANGED
@@ -5,10 +5,8 @@ from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
5
  import orjson
6
 
7
  from annotated_text.util import get_annotated_html
8
- from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
9
  import re
10
  import string
11
- import numpy as np
12
 
13
 
14
  @st.cache(allow_output_mutation=True, show_spinner=False)
@@ -21,28 +19,13 @@ def load_pipeline(chosen_model):
21
 
22
  def extract_keyphrases():
23
  st.session_state.keyphrases = pipe(st.session_state.input_text)
24
- st.session_state.data_frame = pd.concat(
25
- [
26
- st.session_state.data_frame,
27
- pd.DataFrame(
28
- data=[
29
- np.concatenate(
30
- (
31
- [
32
- st.session_state.chosen_model,
33
- st.session_state.input_text,
34
- ],
35
- st.session_state.keyphrases,
36
- )
37
- )
38
- ],
39
- columns=["model", "text"]
40
- + [str(i) for i in range(len(st.session_state.keyphrases))],
41
- ),
42
- ],
43
- ignore_index=True,
44
- axis=0,
45
- ).fillna("")
46
 
47
 
48
  def get_annotated_text(text, keyphrases):
@@ -90,51 +73,36 @@ def get_annotated_text(text, keyphrases):
90
  return result
91
 
92
 
93
- def rerender_output(layout):
94
- layout.write("⚙️ Output")
95
- if (
96
- len(st.session_state.keyphrases) > 0
97
- and len(st.session_state.selected_rows) == 0
98
- ):
99
- text, keyphrases = st.session_state.input_text, st.session_state.keyphrases
100
- else:
101
- text, keyphrases = (
102
- st.session_state.selected_rows["text"].values[0],
103
- [
104
- keyphrase
105
- for keyphrase in st.session_state.selected_rows.loc[
106
- :,
107
- st.session_state.selected_rows.columns.difference(
108
- ["model", "text"]
109
- ),
110
- ]
111
- .astype(str)
112
- .values.tolist()[0]
113
- if keyphrase != ""
114
- ],
115
- )
116
 
117
- result = get_annotated_text(text, list(keyphrases))
 
118
 
119
- layout.markdown(
120
- get_annotated_html(*result),
121
- unsafe_allow_html=True,
122
- )
123
- if "generation" in st.session_state.chosen_model:
124
- abstractive_keyphrases = [
125
- keyphrase
126
- for keyphrase in keyphrases
127
- if keyphrase.lower() not in text.lower()
128
- ]
129
- layout.write(", ".join(abstractive_keyphrases))
 
130
 
131
 
132
  if "config" not in st.session_state:
133
  with open("config.json", "r") as f:
134
  content = f.read()
135
  st.session_state.config = orjson.loads(content)
136
- st.session_state.data_frame = pd.DataFrame(columns=["model"])
137
  st.session_state.keyphrases = []
 
 
138
 
139
  if "select_rows" not in st.session_state:
140
  st.session_state.selected_rows = []
@@ -177,42 +145,43 @@ Do you want to see some magic 🧙‍♂️? Try it out yourself! 👇
177
 
178
  st.write(description)
179
 
180
- with st.form("test"):
181
- chosen_model = st.selectbox(
182
- "Choose your model:",
183
- st.session_state.config.get("models"),
 
184
  )
185
- st.session_state.chosen_model = chosen_model
186
  st.markdown(
187
- f"For more information about the chosen model, please be sure to check it out the [🤗 Model Card](https://huggingface.co/DeDeckerThomas/{chosen_model})."
188
  )
189
 
190
- with st.spinner("Loading pipeline..."):
191
- pipe = load_pipeline(
192
- f"{st.session_state.config.get('model_author')}/{st.session_state.chosen_model}"
193
- )
194
-
195
  st.session_state.input_text = st.text_area(
196
- "✍ Input", st.session_state.config.get("example_text"), height=300
197
  ).replace("\n", " ")
198
 
199
  with st.spinner("Extracting keyphrases..."):
200
- pressed = st.form_submit_button("Extract", on_click=extract_keyphrases)
 
 
 
 
 
 
 
 
201
 
202
- if len(st.session_state.selected_rows) > 0 or len(st.session_state.keyphrases) > 0:
203
- rerender_output(st)
 
 
 
204
 
205
- if len(st.session_state.data_frame.columns) > 0:
206
- st.subheader("📜 History")
207
- builder = GridOptionsBuilder.from_dataframe(
208
- st.session_state.data_frame, sortable=False
209
- )
210
- builder.configure_selection(selection_mode="single", use_checkbox=True)
211
- builder.configure_column("text", hide=True)
212
- go = builder.build()
213
- data = AgGrid(
214
- st.session_state.data_frame,
215
- gridOptions=go,
216
- update_mode=GridUpdateMode.SELECTION_CHANGED,
217
- )
218
- st.session_state.selected_rows = pd.DataFrame(data["selected_rows"])
 
5
  import orjson
6
 
7
  from annotated_text.util import get_annotated_html
 
8
  import re
9
  import string
 
10
 
11
 
12
  @st.cache(allow_output_mutation=True, show_spinner=False)
 
19
 
20
  def extract_keyphrases():
21
  st.session_state.keyphrases = pipe(st.session_state.input_text)
22
+ st.session_state.history[f"run_{st.session_state.current_run_id}"] = {
23
+ "run_id": st.session_state.current_run_id,
24
+ "model": st.session_state.chosen_model,
25
+ "text": st.session_state.input_text,
26
+ "keyphrases": st.session_state.keyphrases,
27
+ }
28
+ st.session_state.current_run_id += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  def get_annotated_text(text, keyphrases):
 
73
  return result
74
 
75
 
76
+ def render_output(layout, runs, reverse=False, multi_select=False):
77
+ runs = list(runs.values())[::-1] if reverse else list(runs.values())
78
+ for run in runs:
79
+ layout.markdown(f"**⚙️ Output run {run.get('run_id')}**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ layout.markdown(f"**Model**: {run.get('model')}")
82
+ result = get_annotated_text(run.get("text"), list(run.get("keyphrases")))
83
 
84
+ layout.markdown(
85
+ get_annotated_html(*result),
86
+ unsafe_allow_html=True,
87
+ )
88
+ if "generation" in st.session_state.chosen_model:
89
+ abstractive_keyphrases = [
90
+ keyphrase
91
+ for keyphrase in run.get("keyphrases")
92
+ if keyphrase.lower() not in run.get("text").lower()
93
+ ]
94
+ layout.write(", ".join(abstractive_keyphrases))
95
+ layout.markdown("---")
96
 
97
 
98
  if "config" not in st.session_state:
99
  with open("config.json", "r") as f:
100
  content = f.read()
101
  st.session_state.config = orjson.loads(content)
102
+ st.session_state.history = {}
103
  st.session_state.keyphrases = []
104
+ st.session_state.current_run_id = 1
105
+ st.session_state.chosen_model = st.session_state.config.get("models")[0]
106
 
107
  if "select_rows" not in st.session_state:
108
  st.session_state.selected_rows = []
 
145
 
146
  st.write(description)
147
 
148
+ with st.form("keyphrase-extraction-form"):
149
+ selectbox_container, _ = st.columns(2)
150
+
151
+ st.session_state.chosen_model = selectbox_container.selectbox(
152
+ "Choose your model:", st.session_state.config.get("models")
153
  )
154
+
155
  st.markdown(
156
+ f"For more information about the chosen model, please be sure to check out the [🤗 Model Card](https://huggingface.co/DeDeckerThomas/{st.session_state.chosen_model})."
157
  )
158
 
 
 
 
 
 
159
  st.session_state.input_text = st.text_area(
160
+ "✍ Input", st.session_state.config.get("example_text"), height=250
161
  ).replace("\n", " ")
162
 
163
  with st.spinner("Extracting keyphrases..."):
164
+ pressed = st.form_submit_button("Extract")
165
+
166
+ if pressed:
167
+ with st.spinner("Loading pipeline..."):
168
+ pipe = load_pipeline(
169
+ f"{st.session_state.config.get('model_author')}/{st.session_state.chosen_model}"
170
+ )
171
+ with st.spinner("Extracting keyphrases"):
172
+ extract_keyphrases()
173
 
174
+ options = st.multiselect(
175
+ "Specify runs you want to see",
176
+ st.session_state.history.keys(),
177
+ format_func=lambda run_id: f"Run {run_id.split('_')[1]}",
178
+ )
179
 
180
+ if len(st.session_state.history.keys()) > 0:
181
+ if options:
182
+ render_output(
183
+ st,
184
+ {key: st.session_state.history[key] for key in options},
185
+ )
186
+ else:
187
+ render_output(st, st.session_state.history, reverse=True)
 
 
 
 
 
 
css/style.css CHANGED
@@ -1,5 +1,24 @@
 
 
 
 
 
1
  @import url('https://fonts.googleapis.com/css2?family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap');
2
 
 
 
 
 
 
3
  body {
4
  font-family: 'Roboto', 'Source Sans Pro', sans-serif;
5
  }
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+
3
+ Fonts
4
+
5
+ */
6
  @import url('https://fonts.googleapis.com/css2?family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap');
7
 
8
+ /*
9
+
10
+ HTML body
11
+
12
+ */
13
  body {
14
  font-family: 'Roboto', 'Source Sans Pro', sans-serif;
15
  }
16
+
17
+ /*
18
+
19
+ Component: Extract Button
20
+
21
+ */
22
+ .css-1cpxqw2{
23
+ float: right;
24
+ }
requirements.txt CHANGED
@@ -3,4 +3,3 @@ transformers[torch]==4.17.0
3
  pandas==1.4.1
4
  numpy==1.22.3
5
  st-annotated-text==3.0.0
6
- streamlit-aggrid==0.2.3.post2
 
3
  pandas==1.4.1
4
  numpy==1.22.3
5
  st-annotated-text==3.0.0