Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import json | |
import time | |
import random | |
import streamlit as st | |
import firebase_admin | |
import logging | |
from firebase_admin import credentials, firestore | |
from dotenv import load_dotenv | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from transformers import pipeline | |
import plotly.graph_objects as go | |
# Configure the root logger once at import time: timestamped, INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Load environment variables via python-dotenv before any function below
# reads os.environ (credentials, model path, auth token).
load_dotenv()
def load_credentials():
    """Assemble the Firebase service-account credentials.

    The non-secret part is read from ``public_creds.json`` in the current
    working directory; ``private_key_id`` and ``private_key`` are taken from
    the environment variables of the same names and overwrite any values
    already present in the file. A missing environment variable results in
    the corresponding entry being ``None``.

    Returns:
        The merged credentials dict, or ``None`` if anything goes wrong
        (the error is logged, never raised).
    """
    try:
        with open('public_creds.json') as creds_file:
            merged = json.load(creds_file)
        # Secrets are never stored in the JSON file; inject them here.
        merged.update(
            private_key_id=os.environ.get('private_key_id'),
            private_key=os.environ.get('private_key'),
        )
        return merged
    except Exception as e:
        logging.error(f'Error while loading credentials: {e}')
        return None
def connect_to_db(credentials_dict):
    """Return a Firestore client for the given service-account credentials.

    The Firebase app is initialised only if no app has been registered yet,
    so calling this repeatedly within one process is safe.

    Args:
        credentials_dict: Service-account info as produced by
            ``load_credentials``.

    Returns:
        A Firestore client, or ``None`` if the certificate could not be
        built or the connection failed (the error is logged, not raised).
    """
    try:
        certificate = credentials.Certificate(credentials_dict)
        app_already_registered = bool(firebase_admin._apps)
        if not app_already_registered:
            firebase_admin.initialize_app(certificate)
        logging.info('Established connection to db!')
        client = firestore.client()
        return client
    except Exception as e:
        logging.error(f'Error while connecting to db: {e}')
        return None
def get_statements_from_db(db):
    """Fetch the list of statements stored in ``ItemDesirability/Items``.

    Args:
        db: Firestore client (anything exposing
            ``collection(...).document(...).get().to_dict()``).

    Returns:
        The value of the document's ``statements`` field, or ``None`` if the
        lookup fails for any reason (the error is logged, not raised).
    """
    try:
        snapshot = db.collection('ItemDesirability').document('Items').get()
        items = snapshot.to_dict()['statements']
        logging.info(f'Retrieved {len(items)} statements from db!')
    except Exception as e:
        logging.error(f'Error while retrieving items from db: {e}')
        return None
    return items
def update_db(db, payload):
    """Append ``payload`` to the ``Data`` array of ``ItemDesirability/Responses``.

    The write is a single ``set(..., merge=True)`` carrying an ``ArrayUnion``
    transform. That creates the document when it does not exist yet and
    appends to the array when it does, in one server-side operation. The
    previous read-then-write sequence needed an extra round trip and could
    lose data if two sessions both saw a missing document and both called
    ``set``.

    Args:
        db: Firestore client.
        payload: Dict describing one response.

    Returns:
        ``True`` if the write succeeded, ``False`` otherwise (the error is
        logged, not raised).
    """
    try:
        doc_ref = db.collection('ItemDesirability').document('Responses')
        doc_ref.set({'Data': firestore.ArrayUnion([payload])}, merge=True)
    except Exception as e:
        logging.error(f'Error while sending payload to db: {e}')
        return False
    logging.info('Sent payload to db!')
    return True
def pick_random(input_list):
    """Return one uniformly chosen element of ``input_list``.

    Returns ``None`` (after logging the error) when no element can be
    chosen, e.g. for an empty sequence or ``None``.
    """
    try:
        chosen = random.choice(input_list)
    except Exception as e:
        logging.error(f'Error while picking random statement: {e}')
        return None
    return chosen
def z_score(y, mean=.04853076, sd=.9409466):
    """Standardise ``y`` by subtracting ``mean`` and dividing by ``sd``."""
    centered = y - mean
    return centered / sd
def score_text(input_text):
    """Score ``input_text`` for sentiment and social desirability.

    Relies on ``classifier``, ``tokenizer`` and ``model`` having been stored
    in ``st.session_state`` beforehand.

    Returns:
        A ``(sentiment, desirability)`` tuple where ``sentiment`` is the
        classifier's ``positive`` score minus its ``negative`` score and
        ``desirability`` is the model's logit passed through ``z_score``.
    """
    state = st.session_state

    # Sentiment: map each label of the first result to its score.
    label_scores = {}
    for entry in state.classifier(input_text)[0]:
        label_scores[entry['label']] = entry['score']
    sentiment = label_scores['positive'] - label_scores['negative']

    # Desirability: forward pass without gradient tracking.
    encoded = state.tokenizer(text=input_text, padding=True, return_tensors='pt')
    with torch.no_grad():
        logits = state.model(**encoded).logits
    raw_score = logits.squeeze().tolist()
    desirability = z_score(raw_score)

    return sentiment, desirability
def indicator_plot(value, title, value_range, domain):
    """Build a plotly gauge (``go.Indicator``) trace for a single score.

    Args:
        value: The score to display; also used as the threshold marker.
        title: Title shown with the gauge.
        value_range: Two-element sequence ``[low, high]`` for the axis.
        domain: Plotly domain dict placing the gauge inside the figure.

    Returns:
        The configured ``go.Indicator`` trace (not yet added to a figure).
    """
    # Delta is measured against 0, coloured by direction.
    delta_cfg = {
        'reference': 0,
        'decreasing': {'color': "#ec4899"},
        'increasing': {'color': "#36def1"},
    }

    axis_cfg = {'range': value_range, 'tickwidth': 1, 'tickcolor': "black"}
    # Two equally coloured background bands, split at 0.
    band_below_zero = {'range': [value_range[0], 0], 'color': '#efefef'}
    band_above_zero = {'range': [0, value_range[1]], 'color': '#efefef'}
    threshold_cfg = {
        'line': {'color': "#4361ee", 'width': 8},
        'thickness': 0.75,
        'value': value,
    }
    gauge_cfg = {
        'axis': axis_cfg,
        'bar': {'color': "#4361ee"},
        'bgcolor': "white",
        'borderwidth': 2,
        'bordercolor': "#efefef",
        'steps': [band_below_zero, band_above_zero],
        'threshold': threshold_cfg,
    }

    return go.Indicator(
        mode="gauge+delta",
        value=value,
        domain=domain,
        title=title,
        delta=delta_cfg,
        gauge=gauge_cfg,
    )
def show_scores(sentiment, desirability, input_text):
    """Render both gauges side by side plus a short legend.

    The sentiment gauge occupies the left part of the figure (x 0 to .45,
    range -1..1) and the desirability gauge the right part (x .55 to 1,
    range -4..4). The quoted ``input_text`` is used as the figure title.
    """
    sentiment_gauge = indicator_plot(
        value=sentiment,
        title='Item Sentiment',
        value_range=[-1, 1],
        domain={'x': [0, .45], 'y': [0, 1]},
    )
    desirability_gauge = indicator_plot(
        value=desirability,
        title='Item Desirability',
        value_range=[-4, 4],
        domain={'x': [.55, 1], 'y': [0, 1]},
    )

    fig = go.Figure()
    for gauge in (sentiment_gauge, desirability_gauge):
        fig.add_trace(gauge)

    figure_title = {
        'text': f'"{input_text}"',
        'font': {'size': 36},
        'yref': 'paper',
    }
    fig.update_layout(
        title=figure_title,
        paper_bgcolor="white",
        font={'color': "black", 'family': "Arial"},
    )

    st.plotly_chart(fig, theme=None, use_container_width=True)
    st.markdown("""
        Item sentiment: Absolute differences between positive and negative sentiment.
        Item desirability: z-transformed values, 0 indicated "neutral".
    """)
def update_statement_placeholder(placeholder):
    """Write the current statement, capitalised and quoted, into ``placeholder``.

    Reads ``st.session_state.current_statement``, which must be a string.
    HTML is enabled so the ``<center>`` tag takes effect.
    """
    shown_statement = st.session_state.current_statement.capitalize()
    prompt = f"""
        Is it socially desirable or undesirable to endorse the following statement?
        ### <center>\"{shown_statement}\"</center>
    """
    placeholder.markdown(body=prompt, unsafe_allow_html=True)
def _ensure_db_connection(max_attempts=3):
    """Make sure ``st.session_state.db`` exists, connecting if it is still None."""
    if 'db' not in st.session_state:
        st.session_state.db = None
    if st.session_state.db is not None:
        return

    # Only read the credentials when a connection attempt is actually needed.
    credentials_dict = load_credentials()
    for attempt in range(1, max_attempts + 1):
        st.session_state.db = connect_to_db(credentials_dict)
        if st.session_state.db is not None:
            return
        if attempt < max_attempts:
            logging.info('Retrying to connect to db...')
            time.sleep(1)


def _ensure_statements(max_attempts=3):
    """Make sure ``statements``/``current_statement`` exist, fetching them if possible."""
    if 'statements' not in st.session_state:
        st.session_state.statements = None
    if 'current_statement' not in st.session_state:
        st.session_state.current_statement = None

    # Without a db there is nothing to fetch; retrying would only add delay.
    if st.session_state.db is None or st.session_state.statements is not None:
        return

    for attempt in range(1, max_attempts + 1):
        statements = get_statements_from_db(st.session_state.db)
        if statements is not None:
            st.session_state.statements = statements
            st.session_state.current_statement = pick_random(statements)
            return
        if attempt < max_attempts:
            logging.info('Retrying to retrieve statements from db...')
            time.sleep(1)


def _ensure_models():
    """Load tokenizer, desirability model and sentiment pipeline once per session."""
    if os.environ.get('item-desirability'):
        model_path = 'magnolia-psychometrics/item-desirability'
    else:
        model_path = os.getenv('model_path')
    # Falls back to True when the 'item-desirability' variable is unset.
    auth_token = os.environ.get('item-desirability') or True

    if 'tokenizer' not in st.session_state:
        st.session_state.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_path,
            use_fast=True,
            use_auth_token=auth_token
        )
    if 'model' not in st.session_state:
        st.session_state.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_path,
            num_labels=1,
            ignore_mismatched_sizes=True,
            use_auth_token=auth_token
        )
    # Sentiment model
    if 'classifier' not in st.session_state:
        st.session_state.sentiment_model = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'
        st.session_state.classifier = pipeline(
            task='sentiment-analysis',
            model=st.session_state.sentiment_model,
            tokenizer=st.session_state.sentiment_model,
            use_fast=False,
            top_k=3
        )


def show():
    """Render the "Try it yourself!" page.

    Steps, in order:
      1. Connect to Firestore and fetch the statements to be rated. Both are
         cached in ``st.session_state``; failures are tolerated.
      2. Show the introduction. If a db connection and a statement are
         available, also show the opt-in data-collection widgets.
      3. Load the models (cached in ``st.session_state``).
      4. On button press, score the entered text and display the gauges.
         When data collection is active, a rating is required and the
         response is sent to the db first.

    Data collection is offered only when both the db and a statement are
    available; otherwise the page falls back to scoring only.
    """
    _ensure_db_connection()
    _ensure_statements()

    # Dial order matches show_scores(): sentiment left, desirability right.
    st.markdown("""
        ## Try it yourself!
        Use the text field below to enter a statement that might be part of a psychological questionnaire (e.g., "I love a good fight.").
        The left dial indicates sentiment (i.e., valence) as estimated by regular sentiment analysis (using the `cardiffnlp/twitter-xlm-roberta-base-sentiment` model).
        The right dial indicates how socially desirable it might be to endorse this item.
    """)

    # A rating can only be collected if there is a db to write to and a
    # statement to rate.
    can_collect = bool(st.session_state.db) and bool(st.session_state.current_statement)
    if can_collect:
        collect_data = st.checkbox(
            label='I want to support and help improve this research.',
            value=True
        )
    else:
        collect_data = False

    if collect_data:
        statement_placeholder = st.empty()
        update_statement_placeholder(statement_placeholder)
        rating_options = ['[Please select]', 'Very undesirable', 'Undesirable', 'Neutral', 'Desirable', 'Very desirable']
        selected_rating = st.selectbox(
            label='Rate the statement above according to whether it is socially desirable or undesirable.',
            options=rating_options,
            index=0
        )
        suitability_options = ['No, I\'m just playing around', 'Yes, my input can help improve this research']
        research_suitability = st.radio(
            label='Is your input suitable for research purposes?',
            options=suitability_options,
            horizontal=True
        )

    with st.spinner('Loading the model might take a couple of seconds...'):
        st.markdown("### Estimate item desirability")
        _ensure_models()

    input_text = st.text_input(
        label='Item text/statement:',
        value='I love a good fight.',
        placeholder='Enter item text'
    )

    if not st.button(label='Evaluate Item Text', type="primary", use_container_width=True):
        return

    if not collect_data:
        sentiment, desirability = score_text(input_text)
        show_scores(sentiment, desirability, input_text)
        return

    # Index 0 is the '[Please select]' sentinel, i.e. no rating given yet.
    if selected_rating == rating_options[0]:
        st.error('Please rate the statement presented above!')
        return

    item_rating = rating_options.index(selected_rating)
    suitability_rating = suitability_options.index(research_suitability)
    sentiment, desirability = score_text(input_text)
    payload = {
        # NOTE(review): user_id is not set anywhere in this function;
        # presumably it is initialised elsewhere in the app - verify.
        'user_id': st.session_state.user_id,
        'statement': st.session_state.current_statement,
        'rating': item_rating,
        'suitability': suitability_rating,
        'input_text': input_text,
        'sentiment': sentiment,
        'desirability': desirability,
    }
    update_success = update_db(
        db=st.session_state.db,
        payload=payload
    )
    if update_success:
        # Move on to a fresh statement only once the response is stored.
        st.session_state.current_statement = pick_random(st.session_state.statements)
        update_statement_placeholder(statement_placeholder)
    # Show the result even if storing the response failed.
    show_scores(sentiment, desirability, input_text)