|
import plotly.express as px |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
import seaborn as sns |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from matplotlib.ticker import PercentFormatter |
|
|
|
|
|
def plot_wow_retention_by_type(wow_retention):
    """Build a line chart of weekly retention rate, one line per trader type.

    Args:
        wow_retention: DataFrame with ``week``, ``retention_rate`` and
            ``trader_type`` columns. ``week`` must be parseable by
            ``pd.to_datetime``. ``retention_rate`` is displayed with a ``%``
            tick suffix, so it is presumably already on a 0-100 scale
            (TODO confirm against the caller). The caller's DataFrame is
            not modified.

    Returns:
        A ``gr.Plot`` wrapping the Plotly line figure.
    """
    # Build a new frame instead of assigning into the argument, so the
    # caller's DataFrame keeps its original ``week`` dtype.
    data = wow_retention.assign(week=pd.to_datetime(wow_retention["week"]))
    # Sort so each trader type's line is drawn in chronological order.
    data = data.sort_values(["trader_type", "week"])

    fig = px.line(
        data,
        x="week",
        y="retention_rate",
        color="trader_type",
        markers=True,
        title="Weekly Retention Rate by Trader Type",
        labels={
            "week": "Week",
            "retention_rate": "Retention Rate (%)",
            "trader_type": "Trader Type",
        },
    )

    # Series.max() skips NaN and returns NaN (rather than raising) for an
    # empty frame. When there is no positive value to scale from, fall back
    # to a 0-100 axis instead of a NaN or zero-height range.
    peak = data["retention_rate"].max()
    y_upper = peak * 1.1 if pd.notna(peak) and peak > 0 else 100

    fig.update_layout(
        hovermode="x unified",
        # Vertical legend placed just right of the plot area; the right
        # margin below leaves room for it.
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,
            orientation="v",
        ),
        yaxis=dict(
            ticksuffix="%",
            range=[0, y_upper],  # 10% headroom above the highest point
        ),
        xaxis=dict(tickformat="%Y-%m-%d"),
        margin=dict(r=150),
    )

    fig.update_traces(
        hovertemplate="<b>%{y:.1f}%</b><br>Week: %{x|%Y-%m-%d}<extra></extra>"
    )

    return gr.Plot(
        value=fig,
    )
|
|
|
|
|
def plot_cohort_retention_heatmap(retention_matrix):
    """Render a cohort retention matrix as an annotated seaborn heatmap.

    Args:
        retention_matrix: DataFrame whose index holds cohort start dates
            (anything ``pd.to_datetime`` accepts) and whose columns are the
            week offsets since first trade. Missing cells (NaN) are masked
            out. Values are presumably percentages; the color scale is fixed
            to 0-100. The caller's DataFrame is not modified.

    Returns:
        A ``gr.Plot`` wrapping the matplotlib figure.
    """
    # Copy so the index relabeling below does not touch the caller's frame.
    retention_matrix = retention_matrix.copy()
    # Show cohort dates as plain YYYY-MM-DD strings on the y axis.
    retention_matrix.index = pd.to_datetime(retention_matrix.index).strftime("%Y-%m-%d")

    # Use an explicit figure/axes pair rather than pyplot's implicit
    # "current axes", which is shared global state across handler calls.
    fig, ax = plt.subplots(figsize=(12, 8))

    # Pass the labels to seaborn so every column gets a tick with its own
    # label. Overwriting them afterwards with set_xticklabels misaligns them
    # whenever seaborn's "auto" mode drops some ticks for wide matrices.
    x_labels = [f"Week {i}" for i in retention_matrix.columns]

    sns.heatmap(
        data=retention_matrix,
        ax=ax,
        annot=True,
        fmt=".1f",
        cmap="YlOrRd",
        vmin=0,
        vmax=100,
        center=50,
        cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
        mask=retention_matrix.isna(),  # leave empty (future) cells blank
        annot_kws={"size": 8},
        xticklabels=x_labels,
    )

    ax.set_title("Cohort Retention Analysis", pad=20, size=14)
    ax.set_xlabel("Weeks Since First Trade", size=12)
    ax.set_ylabel("Cohort Starting Week", size=12)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), rotation=0)

    ax.set_axisbelow(True)

    fig.tight_layout()

    plot = gr.Plot(value=fig)
    # Release pyplot's reference to the figure so repeated calls do not
    # accumulate open figures. The Figure object itself stays renderable.
    plt.close(fig)
    return plot
|
|