qgallouedec HF staff commited on
Commit
2d0e923
·
verified ·
1 Parent(s): 5ee5935

Add num_gpus

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
 
4
 
5
- def plot_forecast(num_param, precision, grad_ckpt, batch_size, seq_len):
6
  # Convert number (input as B)
7
  num_param = float(num_param) * 1e9
8
 
@@ -100,6 +100,8 @@ with gr.Blocks() as demo:
100
  seq_len = gr.Slider(1, 1000, label="Sequence Length", step=1, value=256)
101
 
102
  with gr.Accordion("Advanced", open=False):
 
 
103
  with gr.Accordion("Data"):
104
  grad_ckpt = gr.Checkbox(False, label="Gradient Checkpointing")
105
 
@@ -108,7 +110,7 @@ with gr.Blocks() as demo:
108
  with gr.Column():
109
  plot = gr.Plot(label="forecast", format="png")
110
 
111
- submit.click(plot_forecast, [num_param, precision, grad_ckpt, batch_size, seq_len], plot)
112
 
113
  if __name__ == "__main__":
114
  demo.launch()
 
2
  import matplotlib.pyplot as plt
3
 
4
 
5
+ def plot_forecast(num_param, precision, grad_ckpt, batch_size, seq_len, num_gpus):
6
  # Convert number (input as B)
7
  num_param = float(num_param) * 1e9
8
 
 
100
  seq_len = gr.Slider(1, 1000, label="Sequence Length", step=1, value=256)
101
 
102
  with gr.Accordion("Advanced", open=False):
103
+ with gr.Accordion("Hardware"):
104
+ num_gpus = gr.Number(1, label="Number of GPUs")
105
  with gr.Accordion("Data"):
106
  grad_ckpt = gr.Checkbox(False, label="Gradient Checkpointing")
107
 
 
110
  with gr.Column():
111
  plot = gr.Plot(label="forecast", format="png")
112
 
113
+ submit.click(plot_forecast, [num_param, precision, grad_ckpt, batch_size, seq_len, num_gpus], plot)
114
 
115
  if __name__ == "__main__":
116
  demo.launch()