Spaces:
Running
Running
File size: 2,900 Bytes
addbb37 ac8821a addbb37 f8eee5a addbb37 ac8821a db805e9 ac8821a f8eee5a ac8821a addbb37 ac8821a f8eee5a ac8821a f8eee5a addbb37 f8eee5a addbb37 f8eee5a ac8821a addbb37 f8eee5a addbb37 f8eee5a addbb37 ac8821a addbb37 ac8821a addbb37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import gradio as gr
import matplotlib.pyplot as plt
def plot_forecast(num_param, precision, batch_size, seq_len):
    """Plot a stacked-bar estimate of training memory usage, in GB.

    The chart stacks model parameters and optimizer states in one full-width
    bar, then places two half-width columns side by side on top of it:
    activations on the left, gradients plus optimizer intermediates on the
    right. The title reports the total, which counts only the taller of the
    two columns.

    Args:
        num_param: Number of model parameters, expressed in billions. Any
            value accepted by ``float()``.
        precision: One of ``"float32"``, ``"float16"`` or ``"bfloat16"``.
        batch_size: Batch size used in the activation estimate.
        seq_len: Sequence length used in the activation estimate.

    Returns:
        The matplotlib ``Figure`` containing the chart.

    Raises:
        KeyError: If ``precision`` is not one of the supported names.
    """
    # Build the figure without pyplot: ``plt.figure`` registers every figure
    # in pyplot's global state, where it stays alive until explicitly closed.
    # In a long-running server that leaks one figure per request.
    from matplotlib.figure import Figure

    # Convert number (input as B)
    num_param = float(num_param) * 1e9
    # Convert precision to bytes per value
    bytes_per_value = {"float32": 4, "float16": 2, "bfloat16": 2}[precision]

    # All sizes below are divided by 1e9 to express them in GB.
    # Model Parameters: N×precision
    params_gb = num_param * bytes_per_value / 1e9
    # Optimizer States: 2×N×precision
    optim_states_gb = 2 * num_param * bytes_per_value / 1e9
    # Activations: B×Sequence Length×K×precision
    # NOTE(review): K looks like an empirical linear fit in the parameter
    # count -- confirm where these coefficients come from.
    k = 4.6894e-4 * num_param + 1.8494e6
    activations_gb = batch_size * seq_len * k * bytes_per_value / 1e9
    # Gradients: N×precision
    grads_gb = num_param * bytes_per_value / 1e9
    # Optimizer intermediates: N×precision
    optim_interm_gb = num_param * bytes_per_value / 1e9

    # Only the larger of (activations) and (gradients + optimizer
    # intermediates) is added on top of the always-present parameters and
    # optimizer states.
    total_memory = params_gb + optim_states_gb + max(activations_gb, grads_gb + optim_interm_gb)

    fig = Figure(figsize=(4, 4))
    ax = fig.add_subplot(111)

    # Create stacked bars
    bar_width = 0.5
    base = params_gb + optim_states_gb  # height at which the two half-width columns start
    left_x = -bar_width / 4
    right_x = bar_width / 4
    ax.bar(0, params_gb, width=bar_width, color="r")
    ax.bar(0, optim_states_gb, bottom=params_gb, width=bar_width, color="b")
    ax.bar(left_x, activations_gb, bottom=base, width=bar_width / 2, color="g")
    ax.bar(right_x, grads_gb, bottom=base, width=bar_width / 2, color="y")
    ax.bar(right_x, optim_interm_gb, bottom=base + grads_gb, width=bar_width / 2, color="c")

    # Add text labels centred inside each bar segment
    label_style = {"ha": "center", "va": "center", "color": "white", "fontweight": "bold"}
    ax.text(0, params_gb / 2, f"Model Parameters ({params_gb:.1f} GB)", **label_style)
    ax.text(0, params_gb + optim_states_gb / 2, f"Optimizer States ({optim_states_gb:.1f} GB)", **label_style)
    ax.text(left_x, base + activations_gb / 2, f"Activations\n({activations_gb:.1f} GB)", **label_style)
    ax.text(right_x, base + grads_gb / 2, f"Gradients\n({grads_gb:.1f} GB)", **label_style)
    ax.text(
        right_x,
        base + grads_gb + optim_interm_gb / 2,
        f"Optimizer\nintermediates\n({optim_interm_gb:.1f} GB)",
        **label_style,
    )

    # Report the total as the title
    ax.set_title(f"Total Memory: {total_memory:.1f} GB", fontweight="bold")
    # Remove x-axis: the bar positions carry no meaning
    ax.xaxis.set_visible(False)
    # Set GB as the unit for the y-axis
    ax.set_ylabel("Memory (GB)")
    # Adjust layout
    fig.tight_layout()
    return fig
# Gradio UI: four input widgets feed plot_forecast, and the returned figure
# is rendered as a PNG plot.
demo = gr.Interface(
    fn=plot_forecast,
    inputs=[
        # Parameter count, entered in billions
        gr.Number(value=3, label="Number of parameters (B)"),
        gr.Radio(
            choices=["float32", "float16", "bfloat16"],
            value="float32",
            label="Precision",
        ),
        gr.Slider(minimum=1, maximum=128, value=8, step=1, label="Batch size"),
        gr.Slider(minimum=1, maximum=1000, value=256, step=1, label="Sequence Length"),
    ],
    outputs=gr.Plot(label="forecast", format="png"),
)

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()
|