File size: 2,024 Bytes
addbb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db805e9
 
 
 
addbb37
 
c98b4d9
addbb37
 
 
 
 
 
 
 
 
 
 
 
db805e9
 
addbb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import matplotlib.pyplot as plt


def memory_breakdown(num_param, batch_size, precision, seq_len):
    """Estimate the components of training memory for a model, in GiB.

    Args:
        num_param: Parameter count in billions (anything ``float()``
            accepts, e.g. ``7`` or ``"6.7"``).
        batch_size: Number of sequences per batch.
        precision: One of ``"float32"``, ``"float16"`` or ``"bfloat16"``.
        seq_len: Sequence length in tokens.

    Returns:
        Tuple ``(parameters, activations, optimizer_states, gradients)``,
        each in GiB (units of ``1024**3`` bytes).

    Raises:
        KeyError: If ``precision`` is not one of the supported names.
    """
    n_params = float(num_param) * 1e9
    bytes_per_value = {"float32": 4, "float16": 2, "bfloat16": 2}[precision]
    gib = 1024**3

    # Model parameters: N x precision
    parameters = n_params * bytes_per_value / gib

    # Activations: B x sequence length x K x precision.
    # NOTE(review): K is a linear fit in the parameter count whose origin
    # is not documented here -- presumably activation values stored per
    # token; TODO confirm the source of these coefficients.
    k = 4.6894e-04 * n_params + 1.8494e06
    activations = batch_size * seq_len * k * bytes_per_value / gib

    # Optimizer states: 2 x N x precision (two values per parameter)
    optimizer_states = 2 * n_params * bytes_per_value / gib

    # Gradients: N x precision -- one gradient per parameter, stored in the
    # selected precision (previously hard-coded to 4 bytes, which
    # contradicted this formula for float16/bfloat16).
    gradients = n_params * bytes_per_value / gib

    return parameters, activations, optimizer_states, gradients


def plot_forecast(num_param, batch_size, precision, seq_len):
    """Render the estimated training memory as one stacked, labelled bar.

    Arguments are the same as for :func:`memory_breakdown`.

    Returns:
        A ``matplotlib.figure.Figure`` whose y-axis is memory in GiB
        (labelled "GB"), with segments for parameters, activations,
        optimizer states and gradients stacked in that order.
    """
    # Build the Figure directly rather than through pyplot. pyplot keeps
    # every figure it creates alive in a global registry until
    # plt.close(), so a long-running server would leak one figure per call.
    # pyplot's global state is also not thread-safe.
    from matplotlib.figure import Figure

    segments = zip(
        ("Model Parameters", "Activations", "Optimizer States", "Gradients"),
        ("r", "b", "g", "y"),
        memory_breakdown(num_param, batch_size, precision, seq_len),
    )

    fig = Figure(figsize=(4, 4))
    ax = fig.add_subplot(111)

    # Stack each segment on top of the previous ones; label it at its middle.
    bottom = 0.0
    for label, color, height in segments:
        ax.bar(0, height, bottom=bottom, color=color)
        ax.text(
            0,
            bottom + height / 2,
            label,
            ha="center",
            va="center",
            color="white",
            fontweight="bold",
        )
        bottom += height

    # A single bar at x=0 -- the x axis carries no information.
    ax.xaxis.set_visible(False)

    # Values are computed in units of 1024**3 bytes.
    ax.set_ylabel("Memory (GB)")
    fig.tight_layout()
    return fig


# Web form wiring: model size, batch size, numeric precision and sequence
# length feed plot_forecast, and the figure it returns is shown as a PNG.
demo = gr.Interface(
    fn=plot_forecast,
    inputs=[
        gr.Number(value=7, label="Number of parameters (B)"),
        gr.Radio(
            choices=[1, 2, 4, 8, 16, 32, 64, 128],
            value=8,
            label="Batch size",
        ),
        gr.Radio(
            choices=["float32", "float16", "bfloat16"],
            value="float32",
            label="Precision",
        ),
        gr.Slider(
            minimum=1,
            maximum=1024,
            step=1,
            value=128,
            label="Sequence Length",
        ),
    ],
    outputs=gr.Plot(format="png", label="forecast"),
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    demo.launch()