Spaces:
Runtime error
Runtime error
magnolia-pm
committed on
Commit
•
4ebac20
1
Parent(s):
8660f01
added sentiment predictions
Browse files
app.py
CHANGED
@@ -2,72 +2,141 @@ import os
|
|
2 |
import torch
|
3 |
import streamlit as st
|
4 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
|
|
5 |
import plotly.graph_objects as go
|
6 |
|
7 |
-
input_text = st.text_input(
|
8 |
-
label='Estimate item desirability:',
|
9 |
-
value='I love a good fight.',
|
10 |
-
placeholder='Enter item'
|
11 |
-
|
12 |
-
)
|
13 |
-
|
14 |
-
#model_path = '/nlp/nlp/models/finetuned/twitter-xlm-roberta-base-regressive-desirability-ft-4'
|
15 |
-
model_path = 'magnolia-psychometrics/item-desirability'
|
16 |
-
|
17 |
-
#auth_token = os.environ.get("item-desirability") or True
|
18 |
-
auth_token = "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"
|
19 |
-
|
20 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
21 |
-
pretrained_model_name_or_path=model_path,
|
22 |
-
use_fast=True,
|
23 |
-
use_auth_token=auth_token
|
24 |
-
)
|
25 |
-
|
26 |
-
model = AutoModelForSequenceClassification.from_pretrained(
|
27 |
-
pretrained_model_name_or_path=model_path,
|
28 |
-
num_labels=1,
|
29 |
-
ignore_mismatched_sizes=True,
|
30 |
-
use_auth_token=auth_token
|
31 |
-
)
|
32 |
|
33 |
def z_score(y, mean=.04853076, sd=.9409466):
|
34 |
return (y - mean) / sd
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
inputs = tokenizer(input_text, padding=True, return_tensors='pt')
|
39 |
-
|
40 |
-
with torch.no_grad():
|
41 |
-
score = model(**inputs).logits.squeeze().tolist()
|
42 |
-
z = z_score(score)
|
43 |
|
44 |
-
|
45 |
mode = "gauge+delta",
|
46 |
-
value =
|
47 |
-
domain =
|
48 |
-
title =
|
49 |
delta = {
|
50 |
'reference': 0,
|
51 |
'decreasing': {'color': "#ec4899"},
|
52 |
'increasing': {'color': "#36def1"}
|
53 |
},
|
54 |
gauge = {
|
55 |
-
'axis': {'range':
|
56 |
'bar': {'color': "#4361ee"},
|
57 |
'bgcolor': "white",
|
58 |
'borderwidth': 2,
|
59 |
'bordercolor': "#efefef",
|
60 |
'steps': [
|
61 |
-
{'range': [
|
62 |
-
{'range': [0,
|
|
|
63 |
'threshold': {
|
64 |
'line': {'color': "#4361ee", 'width': 8},
|
65 |
'thickness': 0.75,
|
66 |
-
'value':
|
67 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
fig.update_layout(
|
|
|
70 |
paper_bgcolor = "white",
|
71 |
font = {'color': "black", 'family': "Arial"})
|
72 |
|
73 |
st.plotly_chart(fig, theme=None, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import streamlit as st
|
4 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
5 |
+
from transformers import pipeline
|
6 |
+
from plotly.subplots import make_subplots
|
7 |
import plotly.graph_objects as go
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def z_score(y, mean=.04853076, sd=.9409466):
    """Standardize a raw score by centering and scaling it.

    Args:
        y: Raw score to transform.
        mean: Value subtracted from ``y``. The default is a fixed constant;
            presumably the mean of the reference score distribution -- confirm.
        sd: Divisor applied after centering; presumably the matching
            standard deviation -- confirm.

    Returns:
        ``(y - mean) / sd``.
    """
    return (y - mean) / sd
|
12 |
|
13 |
+
def indicator_plot(value, title, value_range, domain):
    """Build a Plotly gauge-style indicator trace for a single score.

    Args:
        value: Number shown by the gauge; also used as the threshold marker.
        title: Title passed through to the indicator.
        value_range: Two-element sequence ``[low, high]`` used as the axis
            range. It is also split at 0 into two background steps.
        domain: Plotly domain mapping (``{'x': [...], 'y': [...]}``) that
            positions the trace inside the figure.

    Returns:
        The configured ``go.Indicator`` trace (not yet added to a figure).
    """
    plot = go.Indicator(
        mode="gauge+delta",
        value=value,
        domain=domain,
        title=title,
        # The delta is measured against 0, so the sign of the value
        # decides which of the two colors is used.
        delta={
            'reference': 0,
            'decreasing': {'color': "#ec4899"},
            'increasing': {'color': "#36def1"},
        },
        gauge={
            'axis': {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"},
            'bar': {'color': "#4361ee"},
            'bgcolor': "white",
            'borderwidth': 2,
            'bordercolor': "#efefef",
            # Two steps meeting at 0; both currently share one color.
            'steps': [
                {'range': [value_range[0], 0], 'color': '#efefef'},
                {'range': [0, value_range[1]], 'color': '#efefef'},
            ],
            'threshold': {
                'line': {'color': "#4361ee", 'width': 8},
                'thickness': 0.75,
                'value': value,
            },
        },
    )

    return plot
|
44 |
+
|
45 |
+
import os  # kept local to this script section; used for token and path lookups

body = """
# NLP for Item Desirability Ratings
This web application accompanies the paper *Leveraging Natural Language Processing for Item Desirability Ratings:
A Machine-Based Alternative to Human Judges* submitted to the Journal *Personality and Individual Differences*.

## What is this research about?
Researchers use personality scales to measure people's traits and behaviors, but biases can affect the accuracy of these scales.
Socially desirable responding is a common bias that can skew results. To overcome this, researchers gather item desirability ratings, e.g., to ensure that questions are neutral.
Recently, advancements in natural language processing have made it possible to use machines to estimate social desirability ratings,
which can provide a viable alternative to human ratings and help researchers, scale developers, and practitioners improve the accuracy of personality scales.

## Try it yourself!
Use the text field below to enter a statement that might be part of a psychological questionnaire (e.g., "I love a good fight.").
The left dial will indicate how socially desirable it might be to endorse this item.
The right dial indicates sentiment (i.e., valence) as estimated by regular sentiment analysis (using the `cardiffnlp/twitter-xlm-roberta-base-sentiment` model).
"""

st.markdown(body)

input_text = st.text_input(
    label='Estimate item desirability:',
    value='I love a good fight.',
    placeholder='Enter item'
)

# desirability model
# Prefer the developer-local checkpoint when it exists on this machine;
# otherwise fall back to the hub repository so the app also runs elsewhere.
local_model_path = '/nlp/nlp/models/finetuned/twitter-xlm-roberta-base-regressive-desirability-ft-4'
hub_model_path = 'magnolia-psychometrics/item-desirability'
model_path = local_model_path if os.path.isdir(local_model_path) else hub_model_path

# SECURITY: never commit an access token. Read it from the environment
# (e.g. a Space secret); `True` tells transformers to use a locally stored
# login token instead. Any token that was previously committed here should
# be revoked.
auth_token = os.environ.get("item-desirability") or True

# Streamlit re-executes this script in a fresh namespace on every
# interaction, so a `'name' not in globals()` check never skips the load.
# Resource caching keeps the loaded models alive across reruns.
# `experimental_singleton` is the older name of the same facility.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def _load_desirability_model(path, token):
    """Load the tokenizer and regression model for desirability scoring."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=path,
        use_fast=True,
        use_auth_token=token
    )
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=path,
        num_labels=1,
        ignore_mismatched_sizes=True,
        use_auth_token=token
    )
    return loaded_tokenizer, loaded_model


@_cache_resource
def _load_sentiment_classifier(name):
    """Load the sentiment pipeline that returns scores for the top 3 labels."""
    return pipeline("sentiment-analysis", model=name, tokenizer=name, top_k=3)


tokenizer, model = _load_desirability_model(model_path, auth_token)

# sentiment classifier
sentiment_model = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
classifier = _load_sentiment_classifier(sentiment_model)

# Everything that needs a non-empty item lives inside this guard, so an
# empty text field neither runs inference nor touches an undefined figure.
if input_text:

    classifier_output = classifier(input_text)
    classifier_output_dict = {x['label']: x['score'] for x in classifier_output[0]}
    # Signed difference: positive minus negative label score.
    classifier_score = classifier_output_dict['positive'] - classifier_output_dict['negative']

    inputs = tokenizer(input_text, padding=True, return_tensors='pt')

    with torch.no_grad():
        score = model(**inputs).logits.squeeze().tolist()
    z = z_score(score)

    p1 = indicator_plot(
        value=z,
        title="Item Desirability",
        value_range=[-4, 4],
        domain={'x': [0, .45], 'y': [0, 1]},
    )

    p2 = indicator_plot(
        value=classifier_score,
        title="Item Sentiment",
        value_range=[-1, 1],
        domain={'x': [.55, 1], 'y': [0, 1]}
    )

    fig = go.Figure()
    fig.add_trace(p1)
    fig.add_trace(p2)

    fig.update_layout(
        title=dict(text=f'"{input_text}"', font=dict(size=36), yref='paper'),
        paper_bgcolor="white",
        font={'color': "black", 'family': "Arial"})

    st.plotly_chart(fig, theme=None, use_container_width=True)


notes = """
Item desirability: z-transformed values, 0 indicated "neutral".

Item sentiment: Absolute differences between positive and negative sentiment.
"""

st.markdown(notes)
|