# trader_agents_performance/tabs/retention_plots.py
# cyberosa
# First prototype of WoW retention graph for different trader types
# 355fb10
# raw
# history blame
# 2.96 kB
import plotly.express as px
import gradio as gr
import plotly.graph_objects as go
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
def plot_wow_retention_by_type(wow_retention):
    """Build a line chart of weekly retention rate, one line per trader type.

    Args:
        wow_retention: DataFrame with the columns ``week`` (anything
            ``pd.to_datetime`` accepts), ``retention_rate`` (plotted with a
            ``%`` suffix) and ``trader_type``. The frame is not modified.

    Returns:
        A ``gr.Plot`` component wrapping the Plotly figure.
    """
    # ``assign`` returns a new frame, so the caller's DataFrame keeps its
    # original ``week`` column instead of being overwritten in place.
    plot_df = wow_retention.assign(
        week=pd.to_datetime(wow_retention["week"])
    ).sort_values(["trader_type", "week"])

    # ``Series.max`` skips NaN and returns NaN for an empty column, whereas the
    # builtin ``max`` raises on empty input and is unreliable when NaNs are
    # present. Fall back to a plain 0-100 axis when there is no positive peak.
    peak = plot_df["retention_rate"].max()
    if pd.notna(peak) and peak > 0:
        y_upper = peak * 1.1  # 10% headroom above the highest point
    else:
        y_upper = 100

    fig = px.line(
        plot_df,
        x="week",
        y="retention_rate",
        color="trader_type",
        markers=True,
        title="Weekly Retention Rate by Trader Type",
        labels={
            "week": "Week",
            "retention_rate": "Retention Rate (%)",
            "trader_type": "Trader Type",
        },
    )

    fig.update_layout(
        hovermode="x unified",
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,  # just past the right edge of the plot area
            orientation="v",
        ),
        yaxis=dict(
            ticksuffix="%",
            range=[0, y_upper],
        ),
        xaxis=dict(tickformat="%Y-%m-%d"),
        margin=dict(r=150),  # right margin leaves room for the outside legend
    )

    # Same hover text for every trace: bold value, then the week date.
    fig.update_traces(
        hovertemplate="<b>%{y:.1f}%</b><br>Week: %{x|%Y-%m-%d}<extra></extra>"
    )

    return gr.Plot(
        value=fig,
    )
def plot_cohort_retention_heatmap(retention_matrix):
    """Render a cohort retention matrix as an annotated seaborn heatmap.

    Args:
        retention_matrix: DataFrame with one row per cohort (index values must
            be parseable by ``pd.to_datetime``) and one column per week offset.
            Cells are drawn on a fixed 0-100 color scale; NaN cells are masked
            out. The frame is not modified.

    Returns:
        A ``gr.Plot`` component wrapping the matplotlib figure.
    """
    # Copy first so that relabelling the index leaves the caller's frame alone.
    retention_matrix = retention_matrix.copy()
    retention_matrix.index = pd.to_datetime(retention_matrix.index).strftime(
        "%Y-%m-%d"
    )

    # One label per column. Handing these to seaborn (instead of calling
    # ``set_xticklabels`` afterwards) keeps tick and label counts in sync even
    # when seaborn's automatic tick thinning would otherwise drop ticks.
    x_labels = [f"Week {i}" for i in retention_matrix.columns]

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        sns.heatmap(
            data=retention_matrix,
            annot=True,  # print the value inside each cell
            fmt=".1f",  # one decimal place
            cmap="YlOrRd",
            vmin=0,
            vmax=100,
            center=50,
            cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
            mask=retention_matrix.isna(),  # hide cells without data
            annot_kws={"size": 8},
            xticklabels=x_labels,
            ax=ax,
        )

        ax.set_title("Cohort Retention Analysis", pad=20, size=14)
        ax.set_xlabel("Weeks Since First Trade", size=12)
        ax.set_ylabel("Cohort Starting Week", size=12)

        # Slant the week labels so they do not collide; keep cohort dates flat.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        ax.tick_params(axis="y", labelrotation=0)

        # Draw ticks and any grid lines beneath the heatmap cells.
        ax.set_axisbelow(True)

        # Prevent rotated labels and the title from being clipped.
        fig.tight_layout()
    finally:
        # Unregister the figure from pyplot so repeated calls do not pile up
        # open figures; the Figure object itself remains usable for rendering.
        plt.close(fig)

    return gr.Plot(value=fig)