fffiloni commited on
Commit
239dc60
·
verified ·
1 Parent(s): c6e7b6c

gradio can track iterations callbacks for single task

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -3,7 +3,7 @@ from main import main
3
  from arguments import parse_args
4
  import os
5
 
6
- def generate_image(prompt, model, num_iterations, learning_rate, progress = gr.Progress(track_tqdm=True)):
7
  # Set up arguments
8
  args = parse_args()
9
  args.task = "single"
@@ -16,8 +16,11 @@ def generate_image(prompt, model, num_iterations, learning_rate, progress = gr.P
16
  args.save_all_images = True
17
 
18
  try:
19
- # Run the main function
20
- main(args)
 
 
 
21
 
22
  settings = (
23
  f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
@@ -66,7 +69,7 @@ with gr.Blocks() as demo:
66
  with gr.Row():
67
  with gr.Column():
68
  prompt = gr.Textbox(label="Prompt")
69
- chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model")
70
 
71
  with gr.Row():
72
  n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
 
3
  from arguments import parse_args
4
  import os
5
 
6
+ def generate_image(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
7
  # Set up arguments
8
  args = parse_args()
9
  args.task = "single"
 
16
  args.save_all_images = True
17
 
18
  try:
19
+ # Run the main function with progress tracking
20
+ def progress_callback(step):
21
+ progress(step / num_iterations, f"Iteration {step}/{num_iterations}")
22
+
23
+ best_image, total_init_rewards, total_best_rewards = main(args, progress_callback)
24
 
25
  settings = (
26
  f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
 
69
  with gr.Row():
70
  with gr.Column():
71
  prompt = gr.Textbox(label="Prompt")
72
+ chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd"], label="Model", value="sd-turbo")
73
 
74
  with gr.Row():
75
  n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")