File size: 2,963 Bytes
87eca50
355fb10
87eca50
 
355fb10
87eca50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355fb10
 
 
87eca50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355fb10
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import plotly.express as px
import gradio as gr
import plotly.graph_objects as go
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter


def _padded_axis_max(values, default=100.0):
    """Return an upper y-axis bound 10% above the largest finite value.

    Falls back to ``default`` (a full 0-100 percentage scale) when *values*
    is empty, all-NaN, non-positive, or infinite, so the axis range never
    collapses to ``[0, 0]`` or becomes ``[0, nan]``.
    """
    # Series.max() skips NaN, unlike the builtin max(), whose result depends
    # on where a NaN happens to sit in the sequence.
    peak = pd.to_numeric(pd.Series(values), errors="coerce").max()
    if pd.isna(peak) or not 0 < peak < float("inf"):
        return default
    return float(peak) * 1.1


def plot_wow_retention_by_type(wow_retention):
    """Build a line chart of weekly retention rate per trader type.

    Args:
        wow_retention: DataFrame with ``week``, ``retention_rate`` and
            ``trader_type`` columns. ``week`` must be parseable by
            ``pd.to_datetime``. ``retention_rate`` is displayed with a ``%``
            suffix, so it is presumably already on a 0-100 scale (confirm
            with the producer). The caller's frame is not modified.

    Returns:
        A ``gr.Plot`` wrapping the Plotly figure.
    """
    # Convert on a copy instead of overwriting the caller's "week" column.
    data = wow_retention.assign(
        week=pd.to_datetime(wow_retention["week"])
    ).sort_values(["trader_type", "week"])

    fig = px.line(
        data,
        x="week",
        y="retention_rate",
        color="trader_type",
        markers=True,
        title="Weekly Retention Rate by Trader Type",
        labels={
            "week": "Week",
            "retention_rate": "Retention Rate (%)",
            "trader_type": "Trader Type",
        },
    )

    fig.update_layout(
        hovermode="x unified",
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=1.02,  # just past the right edge of the plot area
            orientation="v",
        ),
        yaxis=dict(
            ticksuffix="%",
            # 10% headroom above the highest point; robust to NaN/empty data.
            range=[0, _padded_axis_max(data["retention_rate"])],
        ),
        xaxis=dict(tickformat="%Y-%m-%d"),
        margin=dict(r=150),  # right margin leaves room for the legend
    )

    # Show the rate and the week; <extra></extra> suppresses the trace box.
    fig.update_traces(
        hovertemplate="<b>%{y:.1f}%</b><br>Week: %{x|%Y-%m-%d}<extra></extra>"
    )
    return gr.Plot(
        value=fig,
    )


def plot_cohort_retention_heatmap(retention_matrix):
    """Render a cohort-retention heatmap with seaborn and wrap it for Gradio.

    Args:
        retention_matrix: DataFrame whose index holds each cohort's starting
            week (anything ``pd.to_datetime`` accepts). Its columns are the
            weeks since first trade. Cells are coloured on a fixed 0-100
            scale, so values are presumably percentages (confirm with the
            producer). NaN cells are masked (left blank). The input is not
            modified.

    Returns:
        A ``gr.Plot`` wrapping the matplotlib figure.
    """
    # Reformat the index on a copy so the caller's matrix is left untouched.
    retention_matrix = retention_matrix.copy()
    retention_matrix.index = pd.to_datetime(retention_matrix.index).strftime("%Y-%m-%d")

    # Explicit figure/axes rather than pyplot's implicit "current figure",
    # so this exact figure can be closed once it has been handed to Gradio.
    fig, ax = plt.subplots(figsize=(12, 8))

    sns.heatmap(
        data=retention_matrix,
        ax=ax,
        annot=True,  # show the value in each cell
        fmt=".1f",  # one decimal place
        cmap="YlOrRd",  # yellow -> orange -> red
        vmin=0,
        vmax=100,
        center=50,
        cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
        mask=retention_matrix.isna(),  # leave missing cohort/week cells blank
        annot_kws={"size": 8},
    )

    ax.set_title("Cohort Retention Analysis", pad=20, size=14)
    ax.set_xlabel("Weeks Since First Trade", size=12)
    ax.set_ylabel("Cohort Starting Week", size=12)

    # Prefix the tick labels seaborn actually drew. With many columns its
    # "auto" mode shows only every n-th label, so supplying one label per
    # column would mislabel ticks or raise a FixedLocator count ValueError.
    ax.set_xticklabels(
        [f"Week {tick.get_text()}" for tick in ax.get_xticklabels()],
        rotation=45,
        ha="right",
    )
    # Keep cohort dates horizontal.
    plt.setp(ax.get_yticklabels(), rotation=0)

    # Draws any gridlines beneath the cells; this does not enable gridlines.
    ax.set_axisbelow(True)

    # Prevent titles and rotated labels from being cut off.
    fig.tight_layout()

    plot = gr.Plot(value=fig)
    # Unregister the figure from pyplot; otherwise every call leaks one open
    # figure in a long-running app. A closed Figure can still be rendered.
    plt.close(fig)
    return plot