bjorn-hommel
committed on
Commit
•
46ca3b9
1
Parent(s):
0c985d2
relocated functions to utils
Browse files
app.py
CHANGED
@@ -3,88 +3,14 @@ import torch
|
|
3 |
import dash
|
4 |
import streamlit as st
|
5 |
import pandas as pd
|
|
|
6 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
7 |
from transformers import pipeline
|
8 |
from dotenv import load_dotenv
|
9 |
-
from plotly.subplots import make_subplots
|
10 |
import plotly.graph_objects as go
|
11 |
-
import plotly.express as px
|
12 |
|
13 |
load_dotenv()
|
14 |
|
15 |
-
def z_score(y, mean=.04853076, sd=.9409466):
|
16 |
-
return (y - mean) / sd
|
17 |
-
|
18 |
-
def indicator_plot(value, title, value_range, domain):
|
19 |
-
|
20 |
-
plot = go.Indicator(
|
21 |
-
mode = "gauge+delta",
|
22 |
-
value = value,
|
23 |
-
domain = domain,
|
24 |
-
title = title,
|
25 |
-
delta = {
|
26 |
-
'reference': 0,
|
27 |
-
'decreasing': {'color': "#ec4899"},
|
28 |
-
'increasing': {'color': "#36def1"}
|
29 |
-
},
|
30 |
-
gauge = {
|
31 |
-
'axis': {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"},
|
32 |
-
'bar': {'color': "#4361ee"},
|
33 |
-
'bgcolor': "white",
|
34 |
-
'borderwidth': 2,
|
35 |
-
'bordercolor': "#efefef",
|
36 |
-
'steps': [
|
37 |
-
{'range': [value_range[0], 0], 'color': '#efefef'},
|
38 |
-
{'range': [0, value_range[1]], 'color': '#efefef'}
|
39 |
-
],
|
40 |
-
'threshold': {
|
41 |
-
'line': {'color': "#4361ee", 'width': 8},
|
42 |
-
'thickness': 0.75,
|
43 |
-
'value': value
|
44 |
-
}
|
45 |
-
}
|
46 |
-
)
|
47 |
-
|
48 |
-
return plot
|
49 |
-
|
50 |
-
def scatter_plot(df, group_var):
|
51 |
-
|
52 |
-
colors = ['#36def1', '#4361ee'] if group_var else ['#4361ee']
|
53 |
-
|
54 |
-
plot = px.scatter(
|
55 |
-
df,
|
56 |
-
x='Machine-ratings',
|
57 |
-
y='Human-ratings',
|
58 |
-
color=group_var,
|
59 |
-
facet_col='x_group',
|
60 |
-
facet_col_wrap=2,
|
61 |
-
trendline='ols',
|
62 |
-
trendline_scope='trace',
|
63 |
-
hover_data={
|
64 |
-
'Text': df.text,
|
65 |
-
'Language': False,
|
66 |
-
'x_group': False,
|
67 |
-
'Human-ratings': ':.2f',
|
68 |
-
'Machine-ratings': ':.2f',
|
69 |
-
'Study': df.study,
|
70 |
-
'Instrument': df.instrument,
|
71 |
-
},
|
72 |
-
width=400,
|
73 |
-
height=400,
|
74 |
-
color_discrete_sequence=colors
|
75 |
-
)
|
76 |
-
|
77 |
-
plot.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
|
78 |
-
plot.update_layout(
|
79 |
-
legend={
|
80 |
-
'orientation':'h',
|
81 |
-
'yanchor': 'bottom',
|
82 |
-
'y': -.30
|
83 |
-
})
|
84 |
-
plot.update_xaxes(title_standoff = 0)
|
85 |
-
|
86 |
-
return plot
|
87 |
-
|
88 |
# data import and wrangling
|
89 |
covariate_columns = {
|
90 |
'content_domain': 'Content Domain',
|
@@ -195,16 +121,16 @@ with st.spinner('Processing...'):
|
|
195 |
|
196 |
with torch.no_grad():
|
197 |
score = st.session_state.model(**inputs).logits.squeeze().tolist()
|
198 |
-
z = z_score(score)
|
199 |
|
200 |
-
p1 = indicator_plot(
|
201 |
value=classifier_score,
|
202 |
title=f'Item Sentiment',
|
203 |
value_range=[-1, 1],
|
204 |
domain={'x': [.55, 1], 'y': [0, 1]}
|
205 |
)
|
206 |
|
207 |
-
p2 = indicator_plot(
|
208 |
value=z,
|
209 |
title=f'Item Desirability',
|
210 |
value_range=[-4, 4],
|
@@ -233,7 +159,6 @@ st.markdown("""
|
|
233 |
Figures show the accuarcy in precitions of human-rated item desirability by the sentiment model (left) and the desirability model (right), using `test`-partition data only.
|
234 |
""")
|
235 |
|
236 |
-
|
237 |
show_covariates = st.checkbox('Show covariates', value=True)
|
238 |
|
239 |
if show_covariates:
|
@@ -241,6 +166,6 @@ if show_covariates:
|
|
241 |
else:
|
242 |
option = None
|
243 |
|
244 |
-
plot = scatter_plot(st.session_state.df, option)
|
245 |
|
246 |
st.plotly_chart(plot, theme=None, use_container_width=True)
|
|
|
3 |
import dash
|
4 |
import streamlit as st
|
5 |
import pandas as pd
|
6 |
+
import utils
|
7 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
8 |
from transformers import pipeline
|
9 |
from dotenv import load_dotenv
|
|
|
10 |
import plotly.graph_objects as go
|
|
|
11 |
|
12 |
load_dotenv()
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# data import and wrangling
|
15 |
covariate_columns = {
|
16 |
'content_domain': 'Content Domain',
|
|
|
121 |
|
122 |
with torch.no_grad():
|
123 |
score = st.session_state.model(**inputs).logits.squeeze().tolist()
|
124 |
+
z = utils.z_score(score)
|
125 |
|
126 |
+
p1 = utils.indicator_plot(
|
127 |
value=classifier_score,
|
128 |
title=f'Item Sentiment',
|
129 |
value_range=[-1, 1],
|
130 |
domain={'x': [.55, 1], 'y': [0, 1]}
|
131 |
)
|
132 |
|
133 |
+
p2 = utils.indicator_plot(
|
134 |
value=z,
|
135 |
title=f'Item Desirability',
|
136 |
value_range=[-4, 4],
|
|
|
159 |
Figures show the accuarcy in precitions of human-rated item desirability by the sentiment model (left) and the desirability model (right), using `test`-partition data only.
|
160 |
""")
|
161 |
|
|
|
162 |
show_covariates = st.checkbox('Show covariates', value=True)
|
163 |
|
164 |
if show_covariates:
|
|
|
166 |
else:
|
167 |
option = None
|
168 |
|
169 |
+
plot = utils.scatter_plot(st.session_state.df, option)
|
170 |
|
171 |
st.plotly_chart(plot, theme=None, use_container_width=True)
|
utils.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from plotly.subplots import make_subplots
|
2 |
+
import plotly.graph_objects as go
|
3 |
+
import plotly.express as px
|
4 |
+
|
5 |
+
def z_score(y, mean=.04853076, sd=.9409466):
    """Standardize a raw score as ``(y - mean) / sd``.

    Args:
        y: The raw value to standardize. May be a single number, any object
            that supports ``- mean`` and ``/ sd`` (e.g. an array), or a
            list/tuple of such values (possibly nested), which is
            standardized element by element.
        mean: Value subtracted from ``y``.
            NOTE(review): the default is presumably the mean of the model's
            raw outputs on some reference sample -- confirm its origin.
        sd: Value the centered score is divided by.
            NOTE(review): presumably the matching standard deviation --
            confirm its origin.

    Returns:
        The standardized value. For list/tuple input, a list of the same
        length (and nesting) holding the standardized elements.
    """
    # Plain Python sequences do not support arithmetic with a float, so
    # standardize their elements individually instead of raising TypeError.
    if isinstance(y, (list, tuple)):
        return [z_score(item, mean=mean, sd=sd) for item in y]

    return (y - mean) / sd
|
7 |
+
|
8 |
+
def indicator_plot(value, title, value_range, domain):
    """Build a Plotly gauge indicator trace for a single value.

    Args:
        value: Number shown by the gauge bar, the delta, and the threshold
            marker.
        title: Title passed through to the indicator.
        value_range: Two-element sequence ``[low, high]`` used as the gauge
            axis range. The background is drawn as two steps, ``[low, 0]``
            and ``[0, high]``.
        domain: Plotly domain mapping placing the trace in the figure,
            passed through unchanged.

    Returns:
        A ``plotly.graph_objects.Indicator`` trace (not a full figure).
    """
    low = value_range[0]
    high = value_range[1]

    # The delta is measured against zero and coloured by its direction.
    delta_options = {
        'reference': 0,
        'decreasing': {'color': "#ec4899"},
        'increasing': {'color': "#36def1"}
    }

    # Both halves of the gauge background share the same fill colour.
    background_steps = [
        {'range': [low, 0], 'color': '#efefef'},
        {'range': [0, high], 'color': '#efefef'}
    ]

    # The threshold line is positioned at the displayed value itself.
    threshold_marker = {
        'line': {'color': "#4361ee", 'width': 8},
        'thickness': 0.75,
        'value': value
    }

    gauge_options = {
        'axis': {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"},
        'bar': {'color': "#4361ee"},
        'bgcolor': "white",
        'borderwidth': 2,
        'bordercolor': "#efefef",
        'steps': background_steps,
        'threshold': threshold_marker
    }

    return go.Indicator(
        mode="gauge+delta",
        value=value,
        domain=domain,
        title=title,
        delta=delta_options,
        gauge=gauge_options
    )
|
39 |
+
|
40 |
+
def scatter_plot(df, group_var):
    """Build a faceted scatter plot of human ratings against machine ratings.

    Args:
        df: Data frame read for the columns ``'Machine-ratings'`` (x),
            ``'Human-ratings'`` (y) and ``'x_group'`` (facet column), and for
            the attributes ``text``, ``study`` and ``instrument`` shown on
            hover. The hover settings also name a ``'Language'`` column.
        group_var: Column name used to colour the points, or a falsy value
            (e.g. ``None``) for no grouping.

    Returns:
        The Plotly Express figure, with facet titles reduced to the text
        after the last ``=``, a horizontal legend placed below the plot, and
        the x-axis title standoff set to 0.
    """
    # Two colours when points are grouped, a single colour otherwise.
    if group_var:
        palette = ['#36def1', '#4361ee']
    else:
        palette = ['#4361ee']

    hover_fields = {
        'Text': df.text,
        'Language': False,
        'x_group': False,
        'Human-ratings': ':.2f',
        'Machine-ratings': ':.2f',
        'Study': df.study,
        'Instrument': df.instrument,
    }

    figure = px.scatter(
        df,
        x='Machine-ratings',
        y='Human-ratings',
        color=group_var,
        facet_col='x_group',
        facet_col_wrap=2,
        trendline='ols',
        trendline_scope='trace',
        hover_data=hover_fields,
        width=400,
        height=400,
        color_discrete_sequence=palette
    )

    def _keep_text_after_equals(annotation):
        # Keep only what follows the last '=' in the annotation text.
        annotation.update(text=annotation.text.rsplit('=', 1)[-1])

    figure.for_each_annotation(_keep_text_after_equals)

    legend_options = {
        'orientation': 'h',
        'yanchor': 'bottom',
        'y': -.30
    }
    figure.update_layout(legend=legend_options)
    figure.update_xaxes(title_standoff=0)

    return figure
|