fffiloni commited on
Commit
0fae3db
·
verified ·
1 Parent(s): 67611b7

add weights parameters

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -24,7 +24,7 @@ def list_iter_images(save_dir):
24
  image_paths.extend(glob.glob(os.path.join(save_dir, f'*.{ext}')))
25
 
26
  # Now image_paths contains the list of all image file paths
27
- print(image_paths)
28
 
29
  return image_paths
30
 
@@ -50,11 +50,12 @@ def clean_dir(save_dir):
50
  print(f"{save_dir} does not exist.")
51
 
52
  def start_over(gallery_state):
 
53
  if gallery_state is not None:
54
  gallery_state = None
55
  return gallery_state, None, None, gr.update(visible=False)
56
 
57
- def setup_model(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
58
 
59
  """Clear CUDA memory before starting the training."""
60
  torch.cuda.empty_cache() # Free up cached memory
@@ -70,14 +71,17 @@ def setup_model(prompt, model, num_iterations, learning_rate, progress=gr.Progre
70
  args.save_dir = "./outputs"
71
  args.save_all_images = True
72
 
 
 
 
 
 
73
  args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
74
  loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings]
75
 
76
  return None, loaded_setup
77
 
78
  def generate_image(setup_args, num_iterations):
79
-
80
- """Clear CUDA memory before starting executing task."""
81
  torch.cuda.empty_cache() # Free up cached memory
82
 
83
  args = setup_args[0]
@@ -88,6 +92,7 @@ def generate_image(setup_args, num_iterations):
88
  enable_grad = setup_args[5]
89
 
90
  settings = setup_args[6]
 
91
 
92
  save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
93
  clean_dir(save_dir)
@@ -139,7 +144,7 @@ def generate_image(setup_args, num_iterations):
139
 
140
  except Exception as e:
141
  torch.cuda.empty_cache() # Free up cached memory
142
- yield (None, f"An error occurred: {str(e)}", None)
143
 
144
  def show_gallery_output(gallery_state):
145
  if gallery_state is not None:
@@ -179,6 +184,13 @@ with gr.Blocks() as demo:
179
  n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
180
  learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
181
 
 
 
 
 
 
 
 
182
  submit_btn = gr.Button("Submit")
183
 
184
  gr.Examples(
@@ -204,7 +216,7 @@ with gr.Blocks() as demo:
204
  outputs = [gallery_state, output_image, status, iter_gallery]
205
  ).then(
206
  fn = setup_model,
207
- inputs = [prompt, chosen_model, n_iter, learning_rate],
208
  outputs = [output_image, loaded_model_setup]
209
  ).then(
210
  fn = generate_image,
 
24
  image_paths.extend(glob.glob(os.path.join(save_dir, f'*.{ext}')))
25
 
26
  # Now image_paths contains the list of all image file paths
27
+ #print(image_paths)
28
 
29
  return image_paths
30
 
 
50
  print(f"{save_dir} does not exist.")
51
 
52
  def start_over(gallery_state):
53
+ torch.cuda.empty_cache() # Free up cached memory
54
  if gallery_state is not None:
55
  gallery_state = None
56
  return gallery_state, None, None, gr.update(visible=False)
57
 
58
+ def setup_model(prompt, model, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
59
 
60
  """Clear CUDA memory before starting the training."""
61
  torch.cuda.empty_cache() # Free up cached memory
 
71
  args.save_dir = "./outputs"
72
  args.save_all_images = True
73
 
74
+ args.hps_weighting = hps_w
75
+ args.imagereward_weighting = imgrw_w
76
+ args.pickscore_weighting = pcks_w
77
+ args.clip_weighting = clip_w
78
+
79
  args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
80
  loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings]
81
 
82
  return None, loaded_setup
83
 
84
  def generate_image(setup_args, num_iterations):
 
 
85
  torch.cuda.empty_cache() # Free up cached memory
86
 
87
  args = setup_args[0]
 
92
  enable_grad = setup_args[5]
93
 
94
  settings = setup_args[6]
95
+ print(f"SETTINGS: {settings}")
96
 
97
  save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
98
  clean_dir(save_dir)
 
144
 
145
  except Exception as e:
146
  torch.cuda.empty_cache() # Free up cached memory
147
+ return (None, f"An error occurred: {str(e)}", None)
148
 
149
  def show_gallery_output(gallery_state):
150
  if gallery_state is not None:
 
184
  n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
185
  learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
186
 
187
+ with gr.Accordion("Advanced Settings", open=False):
188
+ with gr.Column():
189
+ hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0)
190
+ imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0)
191
+ pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05)
192
+ clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01)
193
+
194
  submit_btn = gr.Button("Submit")
195
 
196
  gr.Examples(
 
216
  outputs = [gallery_state, output_image, status, iter_gallery]
217
  ).then(
218
  fn = setup_model,
219
+ inputs = [prompt, chosen_model, n_iter, hps_w, imgrw_w, pcks_w, clip_w, learning_rate],
220
  outputs = [output_image, loaded_model_setup]
221
  ).then(
222
  fn = generate_image,