cyberosa commited on
Commit
87eca50
·
1 Parent(s): 8834fdb

Retention plots prototypes

Browse files
Files changed (1) hide show
  1. tabs/retention_plots.py +98 -0
tabs/retention_plots.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import plotly.express as px
2
+ import plotly.graph_objects as go
3
+ import seaborn as sns
4
+ import matplotlib.pyplot as plt
5
+ from matplotlib.ticker import PercentFormatter
6
+
7
+
8
+ def plot_wow_retention_by_type(wow_retention):
9
+ wow_retention["week"] = pd.to_datetime(wow_retention["week"])
10
+ wow_retention = wow_retention.sort_values(["trader_type", "week"])
11
+ fig = px.line(
12
+ wow_retention,
13
+ x="week",
14
+ y="retention_rate",
15
+ color="trader_type",
16
+ markers=True,
17
+ title="Weekly Retention Rate by Trader Type",
18
+ labels={
19
+ "week": "Week",
20
+ "retention_rate": "Retention Rate (%)",
21
+ "trader_type": "Trader Type",
22
+ },
23
+ )
24
+
25
+ fig.update_layout(
26
+ hovermode="x unified",
27
+ legend=dict(
28
+ yanchor="middle",
29
+ y=0.5,
30
+ xanchor="left",
31
+ x=1.02, # Move legend outside
32
+ orientation="v",
33
+ ),
34
+ yaxis=dict(
35
+ ticksuffix="%",
36
+ range=[
37
+ 0,
38
+ max(wow_retention["retention_rate"]) * 1.1,
39
+ ], # Add 10% padding to y-axis
40
+ ),
41
+ xaxis=dict(tickformat="%Y-%m-%d"),
42
+ margin=dict(r=150), # Add right margin to make space for legend
43
+ )
44
+
45
+ # Add hover template
46
+ fig.update_traces(
47
+ hovertemplate="<b>%{y:.1f}%</b><br>Week: %{x|%Y-%m-%d}<extra></extra>"
48
+ )
49
+
50
+ return fig
51
+
52
+
53
+ def plot_cohort_retention_heatmap(retention_matrix):
54
+ # Create a copy of the matrix to avoid modifying the original
55
+ retention_matrix = retention_matrix.copy()
56
+
57
+ # Convert index to datetime and format to date string
58
+ retention_matrix.index = pd.to_datetime(retention_matrix.index).strftime("%Y-%m-%d")
59
+
60
+ # Create figure and axes with specified size
61
+ plt.figure(figsize=(12, 8))
62
+
63
+ # Create mask for NaN values
64
+ mask = retention_matrix.isna()
65
+
66
+ # Create heatmap
67
+ ax = sns.heatmap(
68
+ data=retention_matrix,
69
+ annot=True, # Show numbers in cells
70
+ fmt=".1f", # Format numbers to 1 decimal place
71
+ cmap="YlOrRd", # Yellow to Orange to Red color scheme
72
+ vmin=0,
73
+ vmax=100,
74
+ center=50,
75
+ cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
76
+ mask=mask,
77
+ annot_kws={"size": 8},
78
+ )
79
+
80
+ # Customize the plot
81
+ plt.title("Cohort Retention Analysis", pad=20, size=14)
82
+ plt.xlabel("Weeks Since First Trade", size=12)
83
+ plt.ylabel("Cohort Starting Week", size=12)
84
+
85
+ # Format week numbers on x-axis
86
+ x_labels = [f"Week {i}" for i in retention_matrix.columns]
87
+ ax.set_xticklabels(x_labels, rotation=45, ha="right")
88
+
89
+ # Set y-axis labels rotation
90
+ plt.yticks(rotation=0)
91
+
92
+ # Add gridlines
93
+ ax.set_axisbelow(True)
94
+
95
+ # Adjust layout to prevent label cutoff
96
+ plt.tight_layout()
97
+
98
+ return plt