Spaces:
Runtime error
Runtime error
bjorn-hommel
committed on
Commit
•
228ea6c
1
Parent(s):
aed0724
refactor and database integration
Browse files- app.py +19 -146
- demo_section.py +314 -0
- utils.py → explore_data_section.py +57 -32
- public_creds.json +11 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -3,56 +3,33 @@ import torch
|
|
3 |
import dash
|
4 |
import streamlit as st
|
5 |
import pandas as pd
|
|
|
|
|
6 |
import utils
|
|
|
7 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
8 |
from transformers import pipeline
|
|
|
9 |
from dotenv import load_dotenv
|
10 |
import plotly.graph_objects as go
|
11 |
|
12 |
-
|
|
|
13 |
|
14 |
-
|
15 |
-
covariate_columns = {
|
16 |
-
'content_domain': 'Content Domain',
|
17 |
-
'language': 'Language',
|
18 |
-
'rater_group': 'Rater Group',
|
19 |
-
}
|
20 |
|
21 |
-
if '
|
22 |
-
st.session_state.
|
23 |
-
pd
|
24 |
-
.read_feather(path='data.feather').query('partition == "test" | partition == "dev"')
|
25 |
-
.melt(
|
26 |
-
value_vars=['sentiment_model', 'desirability_model'],
|
27 |
-
var_name='x_group',
|
28 |
-
value_name='x',
|
29 |
-
id_vars=['mean_z', 'text', 'content_domain', 'language', 'rater_group', 'study', 'instrument']
|
30 |
-
)
|
31 |
-
.replace(
|
32 |
-
to_replace={
|
33 |
-
'en': 'English',
|
34 |
-
'de': 'German',
|
35 |
-
'other': 'Other',
|
36 |
-
'personality': 'Personality',
|
37 |
-
'laypeople': 'Laypeople',
|
38 |
-
'students': 'Students',
|
39 |
-
'sentiment_model': 'Sentiment Model',
|
40 |
-
'desirability_model': 'Desirability Model'
|
41 |
-
}
|
42 |
-
)
|
43 |
-
.rename(columns=covariate_columns)
|
44 |
-
.rename(
|
45 |
-
columns={
|
46 |
-
'mean_z': 'Human-ratings',
|
47 |
-
'x': 'Machine-ratings',
|
48 |
-
}
|
49 |
-
)
|
50 |
-
)
|
51 |
|
|
|
|
|
52 |
|
53 |
st.markdown("""
|
54 |
-
#
|
55 |
This web application accompanies the paper "*Expanding the Methodological Toolbox: Machine-Based Item Desirability Ratings as an Alternative to Human-Based Ratings*".
|
|
|
|
|
|
|
56 |
|
57 |
## What is this research about?
|
58 |
Researchers use personality scales to measure people's traits and behaviors, but biases can affect the accuracy of these scales.
|
@@ -61,111 +38,7 @@ st.markdown("""
|
|
61 |
which can provide a viable alternative to human ratings and help researchers, scale developers, and practitioners improve the accuracy of personality scales.
|
62 |
""")
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
The left dial indicates how socially desirable it might be to endorse this item.
|
69 |
-
The right dial indicates sentiment (i.e., valence) as estimated by regular sentiment analysis (using the `cardiffnlp/twitter-xlm-roberta-base-sentiment` model).
|
70 |
-
""")
|
71 |
-
|
72 |
-
## desirability model
|
73 |
-
with st.spinner('Processing...'):
|
74 |
-
|
75 |
-
if os.environ.get('item-desirability'):
|
76 |
-
model_path = 'magnolia-psychometrics/item-desirability'
|
77 |
-
else:
|
78 |
-
model_path = os.getenv('model_path')
|
79 |
-
|
80 |
-
auth_token = os.environ.get('item-desirability') or True
|
81 |
-
|
82 |
-
if 'tokenizer' not in st.session_state:
|
83 |
-
st.session_state.tokenizer = AutoTokenizer.from_pretrained(
|
84 |
-
pretrained_model_name_or_path=model_path,
|
85 |
-
use_fast=True,
|
86 |
-
use_auth_token=auth_token
|
87 |
-
)
|
88 |
-
|
89 |
-
if 'model' not in st.session_state:
|
90 |
-
st.session_state.model = AutoModelForSequenceClassification.from_pretrained(
|
91 |
-
pretrained_model_name_or_path=model_path,
|
92 |
-
num_labels=1,
|
93 |
-
ignore_mismatched_sizes=True,
|
94 |
-
use_auth_token=auth_token
|
95 |
-
)
|
96 |
-
|
97 |
-
## sentiment model
|
98 |
-
if 'classifier' not in st.session_state:
|
99 |
-
st.session_state.sentiment_model = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
|
100 |
-
st.session_state.classifier = pipeline(
|
101 |
-
task='sentiment-analysis',
|
102 |
-
model=st.session_state.sentiment_model,
|
103 |
-
tokenizer=st.session_state.sentiment_model,
|
104 |
-
use_fast=False,
|
105 |
-
top_k=3
|
106 |
-
)
|
107 |
-
|
108 |
-
input_text = st.text_input(
|
109 |
-
label='Estimate item desirability:',
|
110 |
-
value='I love a good fight.',
|
111 |
-
placeholder='Enter item text'
|
112 |
-
)
|
113 |
-
|
114 |
-
if input_text:
|
115 |
-
|
116 |
-
classifier_output = st.session_state.classifier(input_text)
|
117 |
-
classifier_output_dict = {x['label']: x['score'] for x in classifier_output[0]}
|
118 |
-
classifier_score = classifier_output_dict['positive'] - classifier_output_dict['negative']
|
119 |
-
|
120 |
-
inputs = st.session_state.tokenizer(text=input_text, padding=True, return_tensors='pt')
|
121 |
-
|
122 |
-
with torch.no_grad():
|
123 |
-
score = st.session_state.model(**inputs).logits.squeeze().tolist()
|
124 |
-
z = utils.z_score(score)
|
125 |
-
|
126 |
-
p1 = utils.indicator_plot(
|
127 |
-
value=classifier_score,
|
128 |
-
title=f'Item Sentiment',
|
129 |
-
value_range=[-1, 1],
|
130 |
-
domain={'x': [.55, 1], 'y': [0, 1]}
|
131 |
-
)
|
132 |
-
|
133 |
-
p2 = utils.indicator_plot(
|
134 |
-
value=z,
|
135 |
-
title=f'Item Desirability',
|
136 |
-
value_range=[-4, 4],
|
137 |
-
domain={'x': [0, .45], 'y': [0, 1]},
|
138 |
-
)
|
139 |
-
|
140 |
-
fig = go.Figure()
|
141 |
-
fig.add_trace(p1)
|
142 |
-
fig.add_trace(p2)
|
143 |
-
|
144 |
-
fig.update_layout(
|
145 |
-
title=dict(text=f'"{input_text}"', font=dict(size=36),yref='paper'),
|
146 |
-
paper_bgcolor = "white",
|
147 |
-
font = {'color': "black", 'family': "Arial"})
|
148 |
-
|
149 |
-
st.plotly_chart(fig, theme=None, use_container_width=True)
|
150 |
-
|
151 |
-
st.markdown("""
|
152 |
-
Item sentiment: Absolute differences between positive and negative sentiment.
|
153 |
-
Item desirability: z-transformed values, 0 indicated "neutral".
|
154 |
-
""")
|
155 |
-
|
156 |
-
## plot
|
157 |
-
st.markdown("""
|
158 |
-
## Explore the data
|
159 |
-
Figures show the accuarcy in precitions of human-rated item desirability by the sentiment model (left) and the desirability model (right), using `test`-partition data only.
|
160 |
-
""")
|
161 |
-
|
162 |
-
show_covariates = st.checkbox('Show covariates', value=True)
|
163 |
-
|
164 |
-
if show_covariates:
|
165 |
-
option = st.selectbox('Group by', options=list(covariate_columns.values()))
|
166 |
-
else:
|
167 |
-
option = None
|
168 |
-
|
169 |
-
plot = utils.scatter_plot(st.session_state.df, option)
|
170 |
-
|
171 |
-
st.plotly_chart(plot, theme=None, use_container_width=True)
|
|
|
3 |
import dash
|
4 |
import streamlit as st
|
5 |
import pandas as pd
|
6 |
+
import json
|
7 |
+
import random
|
8 |
import utils
|
9 |
+
import firebase_admin
|
10 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
11 |
from transformers import pipeline
|
12 |
+
from firebase_admin import credentials, firestore
|
13 |
from dotenv import load_dotenv
|
14 |
import plotly.graph_objects as go
|
15 |
|
16 |
+
import demo_section
|
17 |
+
import explore_data_section
|
18 |
|
19 |
+
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
if 'collect_data' not in st.session_state:
|
22 |
+
st.session_state.collect_data = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
if 'user_id' not in st.session_state:
|
25 |
+
st.session_state.user_id = random.randint(1, 9999999)
|
26 |
|
27 |
st.markdown("""
|
28 |
+
# Machine-Based Item Desirability Ratings
|
29 |
This web application accompanies the paper "*Expanding the Methodological Toolbox: Machine-Based Item Desirability Ratings as an Alternative to Human-Based Ratings*".
|
30 |
+
|
31 |
+
*Hommel, B. E. (2023). Expanding the methodological toolbox: Machine-based item desirability ratings as an alternative to human-based ratings. Personality and Individual Differences, 213, 112307. https://doi.org/10.1016/j.paid.2023.112307*
|
32 |
+
|
33 |
|
34 |
## What is this research about?
|
35 |
Researchers use personality scales to measure people's traits and behaviors, but biases can affect the accuracy of these scales.
|
|
|
38 |
which can provide a viable alternative to human ratings and help researchers, scale developers, and practitioners improve the accuracy of personality scales.
|
39 |
""")
|
40 |
|
41 |
+
st.divider()
|
42 |
+
demo_section.show()
|
43 |
+
st.divider()
|
44 |
+
explore_data_section.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo_section.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
import random
|
6 |
+
import streamlit as st
|
7 |
+
import firebase_admin
|
8 |
+
import logging
|
9 |
+
from firebase_admin import credentials, firestore
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
12 |
+
from transformers import pipeline
|
13 |
+
import plotly.graph_objects as go
|
14 |
+
|
15 |
+
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
|
16 |
+
|
17 |
+
load_dotenv()
|
18 |
+
|
19 |
+
def load_credentials():
    """Build the Firebase service-account dict from file plus environment.

    The non-secret fields are read from ``public_creds.json`` in the current
    working directory; ``private_key_id`` and ``private_key`` are taken from
    environment variables of the same names and merged into that dict.

    Returns:
        The merged credentials dict, or ``None`` if the file cannot be read
        or parsed, does not contain a JSON object, or either secret is
        missing from the environment.
    """
    try:
        with open('public_creds.json', encoding='utf-8') as f:
            credentials_dict = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass.
        logging.error(f'Error while loading credentials: {e}')
        return None

    if not isinstance(credentials_dict, dict):
        logging.error('Error while loading credentials: public_creds.json does not contain a JSON object')
        return None

    private_key_id = os.environ.get('private_key_id')
    private_key = os.environ.get('private_key')

    # Fail here with a clear message instead of handing a dict with ``None``
    # secrets to the Firebase SDK and getting an opaque error later.
    if not private_key_id or not private_key:
        logging.error('Error while loading credentials: private_key_id and/or private_key are not set in the environment')
        return None

    credentials_dict.update({
        'private_key_id': private_key_id,
        # Secret stores frequently deliver the PEM key with literal "\n"
        # sequences; turn them back into real newlines.
        'private_key': private_key.replace('\\n', '\n'),
    })
    return credentials_dict
|
32 |
+
|
33 |
+
def connect_to_db(credentials_dict):
    """Initialise the default Firebase app (once) and return a Firestore client.

    Args:
        credentials_dict: Service-account credentials as a dict, or ``None``.

    Returns:
        A Firestore client, or ``None`` if no credentials were supplied or
        the connection could not be established.
    """
    if credentials_dict is None:
        logging.error('Error while connecting to db: no credentials available')
        return None

    try:
        cred = credentials.Certificate(credentials_dict)
        try:
            # get_app() raises ValueError when the default app has not been
            # initialised yet; this avoids reading the private ``_apps`` dict.
            firebase_admin.get_app()
        except ValueError:
            firebase_admin.initialize_app(cred)
            logging.info('Established connection to db!')
        return firestore.client()
    except Exception as e:
        # Boundary with the Firebase SDK: log and signal failure to the caller.
        logging.error(f'Error while connecting to db: {e}')
        return None
|
43 |
+
|
44 |
+
def get_statements_from_db(db):
    """Fetch the list of statements stored in ``ItemDesirability/Items``.

    Args:
        db: Firestore client, or ``None`` when no connection exists.

    Returns:
        The value of the document's ``statements`` field, or ``None`` if
        there is no connection, the document or field is missing, or the
        request fails.
    """
    if db is None:
        logging.error('Error while retrieving items from db: no database connection')
        return None

    try:
        document = db.collection('ItemDesirability').document('Items')
        data = document.get().to_dict()
    except Exception as e:
        # Boundary with the Firebase SDK: log and signal failure to the caller.
        logging.error(f'Error while retrieving items from db: {e}')
        return None

    # to_dict() yields None for a non-existent document.
    statements = data.get('statements') if isinstance(data, dict) else None
    if statements is None:
        logging.error("Error while retrieving items from db: document has no 'statements' field")
        return None

    logging.info(f'Retrieved {len(statements)} statements from db!')
    return statements
|
53 |
+
|
54 |
+
def update_db(db, payload):
    """Append ``payload`` to the ``Data`` array of ``ItemDesirability/Responses``.

    Args:
        db: Firestore client, or ``None`` when no connection exists.
        payload: Dict describing a single response.

    Returns:
        ``True`` if the write succeeded, ``False`` otherwise.
    """
    if db is None:
        logging.error('Error while sending payload to db: no database connection')
        return False

    try:
        doc_ref = db.collection('ItemDesirability').document('Responses')
        # A single merge-set replaces the previous get()-then-update/set
        # sequence, so two sessions writing at the same time can no longer
        # race between the existence check and the write.
        doc_ref.set({'Data': firestore.ArrayUnion([payload])}, merge=True)
        logging.info('Sent payload to db!')
        return True
    except Exception as e:
        # Boundary with the Firebase SDK: log and signal failure to the caller.
        logging.error(f'Error while sending payload to db: {e}')
        return False
|
74 |
+
|
75 |
+
def pick_random(input_list):
    """Return one randomly chosen element of ``input_list``.

    Any failure (for example an empty list or ``None``) is logged and
    reported to the caller as ``None``.
    """
    chosen = None
    try:
        chosen = random.choice(input_list)
    except Exception as err:
        logging.error(f'Error while picking random statement: {err}')
    return chosen
|
81 |
+
|
82 |
+
def z_score(y, mean=.04853076, sd=.9409466):
    """Standardise ``y`` by subtracting ``mean`` and dividing by ``sd``.

    NOTE(review): the defaults are presumably the mean and standard deviation
    of the model's raw scores on a reference sample; confirm their origin.
    """
    centered = y - mean
    return centered / sd
|
84 |
+
|
85 |
+
def score_text(input_text):
    """Score ``input_text`` with the sentiment pipeline and the desirability model.

    Both models are read from ``st.session_state`` (``classifier``,
    ``tokenizer`` and ``model``).

    Returns:
        A ``(sentiment, desirability)`` tuple: the classifier's ``positive``
        score minus its ``negative`` score, and the z-standardised output of
        the desirability model.
    """
    state = st.session_state

    # Index the first result's entries by their label.
    label_scores = {}
    for entry in state.classifier(input_text)[0]:
        label_scores[entry['label']] = entry['score']
    sentiment = label_scores['positive'] - label_scores['negative']

    encoded = state.tokenizer(text=input_text, padding=True, return_tensors='pt')

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        raw_score = state.model(**encoded).logits.squeeze().tolist()

    desirability = z_score(raw_score)
    return sentiment, desirability
|
97 |
+
|
98 |
+
def indicator_plot(value, title, value_range, domain):
    """Build a Plotly gauge indicator trace for a single score.

    Args:
        value: Score shown by the gauge bar and the threshold line.
        title: Title of the gauge.
        value_range: Two-element sequence holding the axis minimum and maximum.
        domain: Plotly domain dict that positions the trace in the figure.

    Returns:
        The configured ``go.Indicator`` trace.
    """
    accent_color = "#4361ee"
    track_color = "#efefef"

    # Deltas are measured against 0 and coloured by direction.
    delta_style = {
        'reference': 0,
        'decreasing': {'color': "#ec4899"},
        'increasing': {'color': "#36def1"}
    }

    # Two equally coloured steps cover the ranges below and above 0.
    background_steps = [
        {'range': [value_range[0], 0], 'color': track_color},
        {'range': [0, value_range[1]], 'color': track_color}
    ]

    gauge_style = {
        'axis': {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"},
        'bar': {'color': accent_color},
        'bgcolor': "white",
        'borderwidth': 2,
        'bordercolor': track_color,
        'steps': background_steps,
        'threshold': {
            'line': {'color': accent_color, 'width': 8},
            'thickness': 0.75,
            'value': value
        }
    }

    return go.Indicator(
        mode="gauge+delta",
        value=value,
        domain=domain,
        title=title,
        delta=delta_style,
        gauge=gauge_style
    )
|
129 |
+
|
130 |
+
def show_scores(sentiment, desirability, input_text):
    """Render the sentiment and desirability gauges for ``input_text``.

    The desirability gauge is drawn on the left and the sentiment gauge on
    the right, which is the layout the on-page instructions describe ("The
    left dial indicates how socially desirable ..."). Previously the two
    domains were swapped relative to that text.

    Args:
        sentiment: Sentiment score, shown on a [-1, 1] axis.
        desirability: Desirability z-score, shown on a [-4, 4] axis.
        input_text: The scored text, used as the figure title.
    """
    # Left 45% of the figure width.
    desirability_gauge = indicator_plot(
        value=desirability,
        title='Item Desirability',
        value_range=[-4, 4],
        domain={'x': [0, .45], 'y': [0, 1]},
    )

    # Right 45% of the figure width.
    sentiment_gauge = indicator_plot(
        value=sentiment,
        title='Item Sentiment',
        value_range=[-1, 1],
        domain={'x': [.55, 1], 'y': [0, 1]},
    )

    fig = go.Figure()
    fig.add_trace(sentiment_gauge)
    fig.add_trace(desirability_gauge)

    fig.update_layout(
        title=dict(text=f'"{input_text}"', font=dict(size=36), yref='paper'),
        paper_bgcolor="white",
        font={'color': "black", 'family': "Arial"},
    )

    st.plotly_chart(fig, theme=None, use_container_width=True)

    st.markdown("""
        Item sentiment: Absolute differences between positive and negative sentiment.
        Item desirability: z-transformed values, 0 indicated "neutral".
    """)
|
160 |
+
|
161 |
+
def update_statement_placeholder(placeholder):
    """Render the statement currently up for rating into ``placeholder``.

    The statement is read from ``st.session_state.current_statement``. If no
    statement is available the placeholder is cleared instead of raising.

    Args:
        placeholder: Streamlit placeholder element (as returned by
            ``st.empty()``) that receives the markdown.
    """
    statement = st.session_state.current_statement

    if not statement:
        # No statement could be loaded (e.g. the db was unreachable).
        placeholder.empty()
        return

    # Upper-case only the first character. str.capitalize() would also
    # lower-case the remainder, turning a mid-sentence "I" into "i".
    display_text = statement[:1].upper() + statement[1:]

    placeholder.markdown(
        body=f"""
        Is it socially desirable or undesirable to endorse the following statement?
        ### <center>"{display_text}"</center>
        """,
        unsafe_allow_html=True
    )
|
170 |
+
|
171 |
+
def _retry_until_available(action, retry_message, attempts=3, delay=1):
    """Call ``action`` until it returns a non-``None`` value or attempts run out."""
    for attempt in range(attempts):
        result = action()
        if result is not None:
            return result
        if attempt < attempts - 1:
            logging.info(retry_message)
            time.sleep(delay)
    return None


def _load_models():
    """Load tokenizer, desirability model and sentiment pipeline into the session once."""
    if os.environ.get('item-desirability'):
        model_path = 'magnolia-psychometrics/item-desirability'
    else:
        model_path = os.getenv('model_path')

    # Use the configured token if there is one, otherwise pass ``True``.
    auth_token = os.environ.get('item-desirability') or True

    if 'tokenizer' not in st.session_state:
        st.session_state.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_path,
            use_fast=True,
            use_auth_token=auth_token
        )

    if 'model' not in st.session_state:
        st.session_state.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_path,
            num_labels=1,
            ignore_mismatched_sizes=True,
            use_auth_token=auth_token
        )

    # Sentiment model.
    if 'classifier' not in st.session_state:
        st.session_state.sentiment_model = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
        st.session_state.classifier = pipeline(
            task='sentiment-analysis',
            model=st.session_state.sentiment_model,
            tokenizer=st.session_state.sentiment_model,
            use_fast=False,
            top_k=3
        )


def show():
    """Render the interactive demo section.

    Connects to the database and loads the rating statements (each only while
    still missing from the session), optionally shows a statement to be rated
    by the visitor, loads the models and, when the button is pressed, scores
    the entered text. If data collection is active the rating and the scores
    are sent to the database before the scores are displayed.
    """
    if 'db' not in st.session_state:
        st.session_state.db = None

    if 'statements' not in st.session_state:
        st.session_state.statements = None

    if 'current_statement' not in st.session_state:
        st.session_state.current_statement = None

    # Credentials are only needed while there is no connection yet.
    if st.session_state.db is None:
        credentials_dict = load_credentials()
        if credentials_dict is not None:
            st.session_state.db = _retry_until_available(
                lambda: connect_to_db(credentials_dict),
                'Retrying to connect to db...'
            )

    # Without a connection a retrieval attempt can only fail (and sleep).
    if st.session_state.statements is None and st.session_state.db is not None:
        st.session_state.statements = _retry_until_available(
            lambda: get_statements_from_db(st.session_state.db),
            'Retrying to retrieve statements from db...'
        )

    if st.session_state.current_statement is None and st.session_state.statements:
        st.session_state.current_statement = pick_random(st.session_state.statements)

    st.markdown("""
        ## Try it yourself!
        Use the text field below to enter a statement that might be part of a psychological questionnaire (e.g., "I love a good fight.").
        The left dial indicates how socially desirable it might be to endorse this item.
        The right dial indicates sentiment (i.e., valence) as estimated by regular sentiment analysis (using the `cardiffnlp/twitter-xlm-roberta-base-sentiment` model).
    """)

    # Collecting ratings needs both a connection and a statement to rate.
    can_collect = bool(st.session_state.db) and st.session_state.current_statement is not None

    if can_collect:
        collect_data = st.checkbox(
            label='I want to support and help improve this research.',
            value=True
        )
    else:
        collect_data = False

    rating_options = ['[Please select]', 'Very undesirable', 'Undesirable', 'Neutral', 'Desirable', 'Very desirable']
    suitability_options = ['No, I\'m just playing around', 'Yes, my input can help improve this research']
    selected_rating = rating_options[0]
    research_suitability = suitability_options[0]
    statement_placeholder = None

    if collect_data:
        st.divider()
        statement_placeholder = st.empty()
        update_statement_placeholder(statement_placeholder)

        selected_rating = st.selectbox(
            label='Rate the statement above according to whether it is socially desirable or undesirable.',
            options=rating_options,
            index=0
        )

        research_suitability = st.radio(
            label='Is your input suitable for research purposes?',
            options=suitability_options,
            horizontal=True
        )
        st.divider()

    with st.spinner('Loading the model might take a couple of seconds...'):

        st.markdown("### Estimate item desirability")

        _load_models()

        input_text = st.text_input(
            label='Item text/statement:',
            value='I love a good fight.',
            placeholder='Enter item text'
        )

        if st.button(label='Evaluate Item Text', type="primary", use_container_width=True):
            if collect_data and selected_rating == rating_options[0]:
                # The first option is the "[Please select]" sentinel.
                st.error('Please rate the statement presented above!')
            else:
                sentiment, desirability = score_text(input_text)

                if collect_data:
                    payload = {
                        'user_id': st.session_state.user_id,
                        'statement': st.session_state.current_statement,
                        'rating': rating_options.index(selected_rating),
                        'suitability': suitability_options.index(research_suitability),
                        'input_text': input_text,
                        'sentiment': sentiment,
                        'desirability': desirability,
                    }

                    update_success = update_db(
                        db=st.session_state.db,
                        payload=payload
                    )

                    # Move on to a new statement only once the rating is stored.
                    if update_success:
                        st.session_state.current_statement = pick_random(st.session_state.statements)
                        update_statement_placeholder(statement_placeholder)

                show_scores(sentiment, desirability, input_text)
|
utils.py → explore_data_section.py
RENAMED
@@ -1,42 +1,51 @@
|
|
1 |
-
|
|
|
2 |
import plotly.graph_objects as go
|
3 |
import plotly.express as px
|
4 |
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
7 |
|
8 |
-
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
'
|
17 |
-
'
|
18 |
-
'
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
34 |
}
|
35 |
-
|
36 |
)
|
37 |
|
38 |
-
return plot
|
39 |
-
|
40 |
def scatter_plot(df, group_var):
|
41 |
|
42 |
colors = ['#36def1', '#4361ee'] if group_var else ['#4361ee']
|
@@ -73,4 +82,20 @@ def scatter_plot(df, group_var):
|
|
73 |
})
|
74 |
plot.update_xaxes(title_standoff = 0)
|
75 |
|
76 |
-
return plot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
import plotly.graph_objects as go
|
4 |
import plotly.express as px
|
5 |
|
6 |
+
# Covariate column names in the data file mapped to their display labels.
covariate_columns = {
    'content_domain': 'Content Domain',
    'language': 'Language',
    'rater_group': 'Rater Group',
}

# Columns kept as identifiers when the two model columns are melted to long format.
id_vars = [
    'mean_z', 'text', 'content_domain', 'language',
    'rater_group', 'study', 'instrument'
]

# Prepare the plotting data once per session.
if 'df' not in st.session_state:
    _value_labels = {
        'en': 'English',
        'de': 'German',
        'other': 'Other',
        'personality': 'Personality',
        'laypeople': 'Laypeople',
        'students': 'Students',
        'sentiment_model': 'Sentiment Model',
        'desirability_model': 'Desirability Model'
    }
    _rating_labels = {
        'mean_z': 'Human-ratings',
        'x': 'Machine-ratings',
    }

    _items = pd.read_feather(path='data.feather')
    _items = _items.query('partition == "test" | partition == "dev"')

    # One row per (item, model) pair: the model name goes to ``x_group`` and
    # its score to ``x``.
    _long = _items.melt(
        value_vars=['sentiment_model', 'desirability_model'],
        var_name='x_group',
        value_name='x',
        id_vars=id_vars
    )

    _labelled = _long.replace(to_replace=_value_labels)
    _labelled = _labelled.rename(columns=covariate_columns)
    st.session_state.df = _labelled.rename(columns=_rating_labels)
|
48 |
|
|
|
|
|
49 |
def scatter_plot(df, group_var):
|
50 |
|
51 |
colors = ['#36def1', '#4361ee'] if group_var else ['#4361ee']
|
|
|
82 |
})
|
83 |
plot.update_xaxes(title_standoff = 0)
|
84 |
|
85 |
+
return plot
|
86 |
+
|
87 |
+
def show():
    """Render the "Explore the data" section with an optional grouping selector.

    The scatter plot is drawn from ``st.session_state.df`` when that key is
    present; the "Group by" selector is only shown while the covariates
    checkbox is ticked.
    """
    # NOTE(review): the module-level loader keeps both the "test" and "dev"
    # partitions, while this caption mentions `test` only; confirm which is
    # intended.
    st.markdown("""
        ## Explore the data
        Figures show the accuracy in predictions of human-rated item desirability by the sentiment model (left) and the desirability model (right), using `test`-partition data only.
    """)

    show_covariates = st.checkbox('Show covariates', value=True)
    if show_covariates:
        option = st.selectbox('Group by', options=list(covariate_columns.values()))
    else:
        option = None

    if 'df' in st.session_state:
        plot = scatter_plot(st.session_state.df, option)
        st.plotly_chart(plot, theme=None, use_container_width=True)
|
public_creds.json
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"type": "service_account",
|
3 |
+
"project_id": "huggingfacespaces",
|
4 |
+
"client_email": "firebase-adminsdk-1nwag@huggingfacespaces.iam.gserviceaccount.com",
|
5 |
+
"client_id": "106819644534694903759",
|
6 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
7 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
8 |
+
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
9 |
+
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-1nwag%40huggingfacespaces.iam.gserviceaccount.com",
|
10 |
+
"universe_domain": "googleapis.com"
|
11 |
+
}
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ dash==2.10.2
|
|
5 |
statsmodels==0.14.0
|
6 |
sentencepiece==0.1.99
|
7 |
altair==4.2.2
|
|
|
8 |
python-dotenv
|
|
|
5 |
statsmodels==0.14.0
|
6 |
sentencepiece==0.1.99
|
7 |
altair==4.2.2
|
8 |
+
firebase_admin==6.1.0
|
9 |
python-dotenv
|