bjorn-hommel commited on
Commit
46ca3b9
1 Parent(s): 0c985d2

relocated functions to utils

Browse files
Files changed (2) hide show
  1. app.py +5 -80
  2. utils.py +76 -0
app.py CHANGED
@@ -3,88 +3,14 @@ import torch
3
  import dash
4
  import streamlit as st
5
  import pandas as pd
 
6
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
  from transformers import pipeline
8
  from dotenv import load_dotenv
9
- from plotly.subplots import make_subplots
10
  import plotly.graph_objects as go
11
- import plotly.express as px
12
 
13
  load_dotenv()
14
 
15
- def z_score(y, mean=.04853076, sd=.9409466):
16
- return (y - mean) / sd
17
-
18
- def indicator_plot(value, title, value_range, domain):
19
-
20
- plot = go.Indicator(
21
- mode = "gauge+delta",
22
- value = value,
23
- domain = domain,
24
- title = title,
25
- delta = {
26
- 'reference': 0,
27
- 'decreasing': {'color': "#ec4899"},
28
- 'increasing': {'color': "#36def1"}
29
- },
30
- gauge = {
31
- 'axis': {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"},
32
- 'bar': {'color': "#4361ee"},
33
- 'bgcolor': "white",
34
- 'borderwidth': 2,
35
- 'bordercolor': "#efefef",
36
- 'steps': [
37
- {'range': [value_range[0], 0], 'color': '#efefef'},
38
- {'range': [0, value_range[1]], 'color': '#efefef'}
39
- ],
40
- 'threshold': {
41
- 'line': {'color': "#4361ee", 'width': 8},
42
- 'thickness': 0.75,
43
- 'value': value
44
- }
45
- }
46
- )
47
-
48
- return plot
49
-
50
- def scatter_plot(df, group_var):
51
-
52
- colors = ['#36def1', '#4361ee'] if group_var else ['#4361ee']
53
-
54
- plot = px.scatter(
55
- df,
56
- x='Machine-ratings',
57
- y='Human-ratings',
58
- color=group_var,
59
- facet_col='x_group',
60
- facet_col_wrap=2,
61
- trendline='ols',
62
- trendline_scope='trace',
63
- hover_data={
64
- 'Text': df.text,
65
- 'Language': False,
66
- 'x_group': False,
67
- 'Human-ratings': ':.2f',
68
- 'Machine-ratings': ':.2f',
69
- 'Study': df.study,
70
- 'Instrument': df.instrument,
71
- },
72
- width=400,
73
- height=400,
74
- color_discrete_sequence=colors
75
- )
76
-
77
- plot.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
78
- plot.update_layout(
79
- legend={
80
- 'orientation':'h',
81
- 'yanchor': 'bottom',
82
- 'y': -.30
83
- })
84
- plot.update_xaxes(title_standoff = 0)
85
-
86
- return plot
87
-
88
  # data import and wrangling
89
  covariate_columns = {
90
  'content_domain': 'Content Domain',
@@ -195,16 +121,16 @@ with st.spinner('Processing...'):
195
 
196
  with torch.no_grad():
197
  score = st.session_state.model(**inputs).logits.squeeze().tolist()
198
- z = z_score(score)
199
 
200
- p1 = indicator_plot(
201
  value=classifier_score,
202
  title=f'Item Sentiment',
203
  value_range=[-1, 1],
204
  domain={'x': [.55, 1], 'y': [0, 1]}
205
  )
206
 
207
- p2 = indicator_plot(
208
  value=z,
209
  title=f'Item Desirability',
210
  value_range=[-4, 4],
@@ -233,7 +159,6 @@ st.markdown("""
233
  Figures show the accuarcy in precitions of human-rated item desirability by the sentiment model (left) and the desirability model (right), using `test`-partition data only.
234
  """)
235
 
236
-
237
  show_covariates = st.checkbox('Show covariates', value=True)
238
 
239
  if show_covariates:
@@ -241,6 +166,6 @@ if show_covariates:
241
  else:
242
  option = None
243
 
244
- plot = scatter_plot(st.session_state.df, option)
245
 
246
  st.plotly_chart(plot, theme=None, use_container_width=True)
 
3
  import dash
4
  import streamlit as st
5
  import pandas as pd
6
+ import utils
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
  from transformers import pipeline
9
  from dotenv import load_dotenv
 
10
  import plotly.graph_objects as go
 
11
 
12
  load_dotenv()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # data import and wrangling
15
  covariate_columns = {
16
  'content_domain': 'Content Domain',
 
121
 
122
  with torch.no_grad():
123
  score = st.session_state.model(**inputs).logits.squeeze().tolist()
124
+ z = utils.z_score(score)
125
 
126
+ p1 = utils.indicator_plot(
127
  value=classifier_score,
128
  title=f'Item Sentiment',
129
  value_range=[-1, 1],
130
  domain={'x': [.55, 1], 'y': [0, 1]}
131
  )
132
 
133
+ p2 = utils.indicator_plot(
134
  value=z,
135
  title=f'Item Desirability',
136
  value_range=[-4, 4],
 
159
  Figures show the accuarcy in precitions of human-rated item desirability by the sentiment model (left) and the desirability model (right), using `test`-partition data only.
160
  """)
161
 
 
162
  show_covariates = st.checkbox('Show covariates', value=True)
163
 
164
  if show_covariates:
 
166
  else:
167
  option = None
168
 
169
+ plot = utils.scatter_plot(st.session_state.df, option)
170
 
171
  st.plotly_chart(plot, theme=None, use_container_width=True)
utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from plotly.subplots import make_subplots
2
+ import plotly.graph_objects as go
3
+ import plotly.express as px
4
+
5
+ def z_score(y, mean=.04853076, sd=.9409466):
6
+ return (y - mean) / sd
7
+
8
+ def indicator_plot(value, title, value_range, domain):
9
+
10
+ plot = go.Indicator(
11
+ mode = "gauge+delta",
12
+ value = value,
13
+ domain = domain,
14
+ title = title,
15
+ delta = {
16
+ 'reference': 0,
17
+ 'decreasing': {'color': "#ec4899"},
18
+ 'increasing': {'color': "#36def1"}
19
+ },
20
+ gauge = {
21
+ 'axis': {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"},
22
+ 'bar': {'color': "#4361ee"},
23
+ 'bgcolor': "white",
24
+ 'borderwidth': 2,
25
+ 'bordercolor': "#efefef",
26
+ 'steps': [
27
+ {'range': [value_range[0], 0], 'color': '#efefef'},
28
+ {'range': [0, value_range[1]], 'color': '#efefef'}
29
+ ],
30
+ 'threshold': {
31
+ 'line': {'color': "#4361ee", 'width': 8},
32
+ 'thickness': 0.75,
33
+ 'value': value
34
+ }
35
+ }
36
+ )
37
+
38
+ return plot
39
+
40
+ def scatter_plot(df, group_var):
41
+
42
+ colors = ['#36def1', '#4361ee'] if group_var else ['#4361ee']
43
+
44
+ plot = px.scatter(
45
+ df,
46
+ x='Machine-ratings',
47
+ y='Human-ratings',
48
+ color=group_var,
49
+ facet_col='x_group',
50
+ facet_col_wrap=2,
51
+ trendline='ols',
52
+ trendline_scope='trace',
53
+ hover_data={
54
+ 'Text': df.text,
55
+ 'Language': False,
56
+ 'x_group': False,
57
+ 'Human-ratings': ':.2f',
58
+ 'Machine-ratings': ':.2f',
59
+ 'Study': df.study,
60
+ 'Instrument': df.instrument,
61
+ },
62
+ width=400,
63
+ height=400,
64
+ color_discrete_sequence=colors
65
+ )
66
+
67
+ plot.for_each_annotation(lambda a: a.update(text=a.text.split('=')[-1]))
68
+ plot.update_layout(
69
+ legend={
70
+ 'orientation':'h',
71
+ 'yanchor': 'bottom',
72
+ 'y': -.30
73
+ })
74
+ plot.update_xaxes(title_standoff = 0)
75
+
76
+ return plot