item-desirability-demo / explore_data_section.py
bjorn-hommel's picture
refactor and database integration
228ea6c
raw
history blame
2.9 kB
import streamlit as st
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
# Covariates offered in the "Group by" selector: raw column name -> display
# label. Insertion order determines the order of the selectbox options.
covariate_columns = dict(
    content_domain='Content Domain',
    language='Language',
    rater_group='Rater Group',
)

# Columns kept as identifier columns when the two model-score columns are
# melted into long format.
id_vars = [
    'mean_z',
    'text',
    'content_domain',
    'language',
    'rater_group',
    'study',
    'instrument',
]
# Display labels for coded values. They are applied only to the categorical
# columns in `_relabelled_columns`, so free-text and metadata columns (e.g.
# `text`, `study`, `instrument`) are never rewritten when a cell happens to
# equal one of the codes exactly.
_value_labels = {
    'en': 'English',
    'de': 'German',
    'other': 'Other',
    'personality': 'Personality',
    'laypeople': 'Laypeople',
    'students': 'Students',
    'sentiment_model': 'Sentiment Model',
    'desirability_model': 'Desirability Model',
}
_relabelled_columns = ('content_domain', 'language', 'rater_group', 'x_group')

# Build the plotting frame once per Streamlit session:
# - keep only the `test` and `dev` partitions;
# - melt the two model-score columns into long format (`x_group` names the
#   model, `x` holds its score);
# - relabel coded values for display;
# - rename columns to their display names.
if 'df' not in st.session_state:
    st.session_state.df = (
        pd
        .read_feather(path='data.feather')
        .query('partition == "test" | partition == "dev"')
        .melt(
            value_vars=['sentiment_model', 'desirability_model'],
            var_name='x_group',
            value_name='x',
            id_vars=id_vars,
        )
        # Nested dict: {column: {old: new}} restricts replacement per column.
        .replace(to_replace={col: _value_labels for col in _relabelled_columns})
        .rename(
            columns={
                **covariate_columns,
                'mean_z': 'Human-ratings',
                'x': 'Machine-ratings',
            }
        )
    )
def scatter_plot(df, group_var):
    """Build faceted human-vs-machine rating scatter plots with OLS trendlines.

    There is one facet per model (column ``x_group``). Machine ratings go on
    the x-axis and human ratings on the y-axis, and every trace gets its own
    OLS trendline.

    Args:
        df: Frame with the columns 'Machine-ratings', 'Human-ratings',
            'x_group', 'Language', 'text', 'study' and 'instrument'.
        group_var: Column used to colour the points, or a falsy value for a
            single colour.

    Returns:
        The configured plotly figure.
    """
    # Two-colour palette when grouping; a single colour otherwise.
    if group_var:
        palette = ['#36def1', '#4361ee']
    else:
        palette = ['#4361ee']

    # Hover fields: False hides a column, a format string formats a column,
    # and a Series adds an extra field under the given label.
    hover_fields = {
        'Text': df['text'],
        'Language': False,
        'x_group': False,
        'Human-ratings': ':.2f',
        'Machine-ratings': ':.2f',
        'Study': df['study'],
        'Instrument': df['instrument'],
    }

    fig = px.scatter(
        df,
        x='Machine-ratings',
        y='Human-ratings',
        color=group_var,
        facet_col='x_group',
        facet_col_wrap=2,
        trendline='ols',
        trendline_scope='trace',
        hover_data=hover_fields,
        width=400,
        height=400,
        color_discrete_sequence=palette,
    )

    def _strip_facet_prefix(annotation):
        # Facet titles have the form "column=value"; keep only the value.
        annotation.update(text=annotation.text.split('=')[-1])

    fig.for_each_annotation(_strip_facet_prefix)
    # Horizontal legend placed below the plotting area.
    fig.update_layout(legend={'orientation': 'h', 'yanchor': 'bottom', 'y': -.30})
    fig.update_xaxes(title_standoff=0)
    return fig
def show():
    """Render the "Explore the data" page.

    Writes an introductory text and a "Show covariates" checkbox. When the
    checkbox is ticked, a "Group by" selector lists the covariate display
    labels. The page then draws the scatter plot from
    ``st.session_state.df``, coloured by the selected covariate if any.
    """
    # The description matches the data actually loaded: rows from both the
    # `test` and `dev` partitions are kept.
    st.markdown("""
    ## Explore the data
    Figures show how accurately the sentiment model (left) and the desirability model (right) predict human-rated item desirability, using data from the `test` and `dev` partitions.
    """)

    show_covariates = st.checkbox('Show covariates', value=True)
    if show_covariates:
        # Options are display labels, which are also the renamed column names.
        option = st.selectbox('Group by', options=list(covariate_columns.values()))
    else:
        option = None

    if 'df' in st.session_state:
        plot = scatter_plot(st.session_state.df, option)
        st.plotly_chart(plot, theme=None, use_container_width=True)