import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter


def plot_wow_retention_by_type(wow_retention):
    """Plot weekly retention rate as one line per trader type.

    Args:
        wow_retention: DataFrame with ``week``, ``retention_rate`` and
            ``trader_type`` columns. ``week`` must be parseable by
            ``pd.to_datetime``. ``retention_rate`` is displayed with a ``%``
            suffix, so it is presumably already on a 0-100 scale (confirm
            against the caller). The caller's DataFrame is not modified.

    Returns:
        The plotly figure built by ``px.line``.
    """
    # Work on a copy so the caller's DataFrame keeps its original "week" column.
    data = wow_retention.copy()
    data["week"] = pd.to_datetime(data["week"])
    data = data.sort_values(["trader_type", "week"])

    fig = px.line(
        data,
        x="week",
        y="retention_rate",
        color="trader_type",
        markers=True,
        title="Weekly Retention Rate by Trader Type",
        labels={
            "week": "Week",
            "retention_rate": "Retention Rate (%)",
            "trader_type": "Trader Type",
        },
    )

    # Series.max() skips NaN; builtin max() gives position-dependent results
    # when NaN is present and raises on empty input. Without a usable positive
    # maximum, leave the range unset so plotly autoranges.
    yaxis = {"ticksuffix": "%"}
    max_rate = data["retention_rate"].max()
    if pd.notna(max_rate) and max_rate > 0:
        yaxis["range"] = [0, max_rate * 1.1]  # 10% headroom above the peak

    fig.update_layout(
        hovermode="x unified",
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,  # place the legend just outside the plot area
            orientation="v",
        ),
        yaxis=yaxis,
        xaxis=dict(tickformat="%Y-%m-%d"),
        margin=dict(r=150),  # right margin leaves room for the outside legend
    )

    # <br> is plotly's line break inside hover text.
    fig.update_traces(hovertemplate="%{y:.1f}%<br>Week: %{x|%Y-%m-%d}")

    return fig


def plot_cohort_retention_heatmap(retention_matrix):
    """Draw an annotated cohort-retention heatmap with seaborn/matplotlib.

    Args:
        retention_matrix: DataFrame whose index is convertible with
            ``pd.to_datetime`` (rendered as ``YYYY-MM-DD`` cohort labels) and
            whose columns are rendered as ``"Week {column}"``. Values are
            colour-scaled on 0-100, so they are presumably percentages (confirm
            against the caller). NaN cells are masked. The input is not modified.

    Returns:
        The ``matplotlib.pyplot`` module, kept for backward compatibility with
        existing callers. The drawn figure is ``plt.gcf()``.
    """
    # Copy so relabelling the index does not touch the caller's matrix.
    retention_matrix = retention_matrix.copy()
    retention_matrix.index = pd.to_datetime(retention_matrix.index).strftime("%Y-%m-%d")

    plt.figure(figsize=(12, 8))

    # Pass the labels to seaborn instead of overwriting ticks afterwards.
    # With many columns, seaborn's default "auto" ticks only a subset, and
    # set_xticklabels(full_list) would then put labels on the wrong cells.
    x_labels = [f"Week {i}" for i in retention_matrix.columns]

    ax = sns.heatmap(
        data=retention_matrix,
        annot=True,  # show values in cells
        fmt=".1f",  # one decimal place
        cmap="YlOrRd",  # yellow -> orange -> red
        vmin=0,
        vmax=100,
        center=50,
        cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
        mask=retention_matrix.isna(),  # leave missing cohort/week cells blank
        annot_kws={"size": 8},
        xticklabels=x_labels,
    )

    plt.title("Cohort Retention Analysis", pad=20, size=14)
    plt.xlabel("Weeks Since First Trade", size=12)
    plt.ylabel("Cohort Starting Week", size=12)

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.yticks(rotation=0)

    # Prevent rotated tick labels from being clipped.
    plt.tight_layout()

    return plt