Paula Leonova
commited on
Commit
·
d4be6e6
1
Parent(s):
2b16dfe
Add evaluation metrics
Browse files- app.py +20 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
|
|
5 |
import base64
|
6 |
from typing import Sequence
|
7 |
import streamlit as st
|
|
|
|
|
8 |
|
9 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
10 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
@@ -102,7 +104,16 @@ if submit_button:
|
|
102 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
103 |
|
104 |
data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
|
|
|
105 |
data2 = pd.merge(data, data_ex_text, on = ['label'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
st.markdown("### Data Table")
|
107 |
|
108 |
with st.spinner('Generating a table of results and a download link...'):
|
@@ -112,5 +123,14 @@ if submit_button:
|
|
112 |
unsafe_allow_html = True
|
113 |
)
|
114 |
st.dataframe(data2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
st.success('All done!')
|
116 |
st.balloons()
|
|
|
5 |
import base64
|
6 |
from typing import Sequence
|
7 |
import streamlit as st
|
8 |
+
from sklearn.metrics import classification_report
|
9 |
+
|
10 |
|
11 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
12 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
|
|
104 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
105 |
|
106 |
data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
|
107 |
+
|
108 |
data2 = pd.merge(data, data_ex_text, on = ['label'])
|
109 |
+
|
110 |
+
if len(glabels) > 0:
|
111 |
+
gdata = pd.DataFrame({'label': glabels})
|
112 |
+
gdata['is_true_label'] = 1
|
113 |
+
|
114 |
+
data2 = pd.merge(data2, gdata, how = 'left', on = ['label'])
|
115 |
+
data2['is_true_label'].fillna(0, inplace = True)
|
116 |
+
|
117 |
st.markdown("### Data Table")
|
118 |
|
119 |
with st.spinner('Generating a table of results and a download link...'):
|
|
|
123 |
unsafe_allow_html = True
|
124 |
)
|
125 |
st.dataframe(data2)
|
126 |
+
|
127 |
+
if len(glabels) > 0:
|
128 |
+
with st.spinner('Evaluating output against ground truth...'):
|
129 |
+
report = classification_report(y_true = data2[['is_true_label']],
|
130 |
+
y_pred = (data2[['scores_from_full_text']] >= threshold_value) * 1.0,
|
131 |
+
output_dict=True)
|
132 |
+
df_report = pd.DataFrame(report).transpose()
|
133 |
+
st.dataframe(df_report)
|
134 |
+
|
135 |
st.success('All done!')
|
136 |
st.balloons()
|
requirements.txt
CHANGED
@@ -3,5 +3,6 @@ pandas
|
|
3 |
streamlit
|
4 |
plotly
|
5 |
torch
|
|
|
6 |
spacy>=2.2.0,<3.0.0
|
7 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm
|
|
|
3 |
streamlit
|
4 |
plotly
|
5 |
torch
|
6 |
+
sklearn
|
7 |
spacy>=2.2.0,<3.0.0
|
8 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm
|