Spaces:
Sleeping
Sleeping
add weights parameters
Browse files
app.py
CHANGED
@@ -24,7 +24,7 @@ def list_iter_images(save_dir):
|
|
24 |
image_paths.extend(glob.glob(os.path.join(save_dir, f'*.{ext}')))
|
25 |
|
26 |
# Now image_paths contains the list of all image file paths
|
27 |
-
print(image_paths)
|
28 |
|
29 |
return image_paths
|
30 |
|
@@ -50,11 +50,12 @@ def clean_dir(save_dir):
|
|
50 |
print(f"{save_dir} does not exist.")
|
51 |
|
52 |
def start_over(gallery_state):
|
|
|
53 |
if gallery_state is not None:
|
54 |
gallery_state = None
|
55 |
return gallery_state, None, None, gr.update(visible=False)
|
56 |
|
57 |
-
def setup_model(prompt, model, num_iterations, learning_rate, progress=gr.Progress(track_tqdm=True)):
|
58 |
|
59 |
"""Clear CUDA memory before starting the training."""
|
60 |
torch.cuda.empty_cache() # Free up cached memory
|
@@ -70,14 +71,17 @@ def setup_model(prompt, model, num_iterations, learning_rate, progress=gr.Progre
|
|
70 |
args.save_dir = "./outputs"
|
71 |
args.save_all_images = True
|
72 |
|
|
|
|
|
|
|
|
|
|
|
73 |
args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
|
74 |
loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings]
|
75 |
|
76 |
return None, loaded_setup
|
77 |
|
78 |
def generate_image(setup_args, num_iterations):
|
79 |
-
|
80 |
-
"""Clear CUDA memory before starting executing task."""
|
81 |
torch.cuda.empty_cache() # Free up cached memory
|
82 |
|
83 |
args = setup_args[0]
|
@@ -88,6 +92,7 @@ def generate_image(setup_args, num_iterations):
|
|
88 |
enable_grad = setup_args[5]
|
89 |
|
90 |
settings = setup_args[6]
|
|
|
91 |
|
92 |
save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
|
93 |
clean_dir(save_dir)
|
@@ -139,7 +144,7 @@ def generate_image(setup_args, num_iterations):
|
|
139 |
|
140 |
except Exception as e:
|
141 |
torch.cuda.empty_cache() # Free up cached memory
|
142 |
-
|
143 |
|
144 |
def show_gallery_output(gallery_state):
|
145 |
if gallery_state is not None:
|
@@ -179,6 +184,13 @@ with gr.Blocks() as demo:
|
|
179 |
n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
|
180 |
learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
submit_btn = gr.Button("Submit")
|
183 |
|
184 |
gr.Examples(
|
@@ -204,7 +216,7 @@ with gr.Blocks() as demo:
|
|
204 |
outputs = [gallery_state, output_image, status, iter_gallery]
|
205 |
).then(
|
206 |
fn = setup_model,
|
207 |
-
inputs = [prompt, chosen_model, n_iter, learning_rate],
|
208 |
outputs = [output_image, loaded_model_setup]
|
209 |
).then(
|
210 |
fn = generate_image,
|
|
|
24 |
image_paths.extend(glob.glob(os.path.join(save_dir, f'*.{ext}')))
|
25 |
|
26 |
# Now image_paths contains the list of all image file paths
|
27 |
+
#print(image_paths)
|
28 |
|
29 |
return image_paths
|
30 |
|
|
|
50 |
print(f"{save_dir} does not exist.")
|
51 |
|
52 |
def start_over(gallery_state):
|
53 |
+
torch.cuda.empty_cache() # Free up cached memory
|
54 |
if gallery_state is not None:
|
55 |
gallery_state = None
|
56 |
return gallery_state, None, None, gr.update(visible=False)
|
57 |
|
58 |
+
def setup_model(prompt, model, num_iterations, learning_rate, hps_w, imgrw_w, pcks_w, clip_w, progress=gr.Progress(track_tqdm=True)):
|
59 |
|
60 |
"""Clear CUDA memory before starting the training."""
|
61 |
torch.cuda.empty_cache() # Free up cached memory
|
|
|
71 |
args.save_dir = "./outputs"
|
72 |
args.save_all_images = True
|
73 |
|
74 |
+
args.hps_weighting = hps_w
|
75 |
+
args.imagereward_weighting = imgrw_w
|
76 |
+
args.pickscore_weighting = pcks_w
|
77 |
+
args.clip_weighting = clip_w
|
78 |
+
|
79 |
args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
|
80 |
loaded_setup = [args, trainer, device, dtype, shape, enable_grad, settings]
|
81 |
|
82 |
return None, loaded_setup
|
83 |
|
84 |
def generate_image(setup_args, num_iterations):
|
|
|
|
|
85 |
torch.cuda.empty_cache() # Free up cached memory
|
86 |
|
87 |
args = setup_args[0]
|
|
|
92 |
enable_grad = setup_args[5]
|
93 |
|
94 |
settings = setup_args[6]
|
95 |
+
print(f"SETTINGS: {settings}")
|
96 |
|
97 |
save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt}"
|
98 |
clean_dir(save_dir)
|
|
|
144 |
|
145 |
except Exception as e:
|
146 |
torch.cuda.empty_cache() # Free up cached memory
|
147 |
+
return (None, f"An error occurred: {str(e)}", None)
|
148 |
|
149 |
def show_gallery_output(gallery_state):
|
150 |
if gallery_state is not None:
|
|
|
184 |
n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
|
185 |
learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
|
186 |
|
187 |
+
with gr.Accordion("Advanced Settings", open=False):
|
188 |
+
with gr.Column():
|
189 |
+
hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0)
|
190 |
+
imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0)
|
191 |
+
pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05)
|
192 |
+
clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01)
|
193 |
+
|
194 |
submit_btn = gr.Button("Submit")
|
195 |
|
196 |
gr.Examples(
|
|
|
216 |
outputs = [gallery_state, output_image, status, iter_gallery]
|
217 |
).then(
|
218 |
fn = setup_model,
|
219 |
+
inputs = [prompt, chosen_model, n_iter, hps_w, imgrw_w, pcks_w, clip_w, learning_rate],
|
220 |
outputs = [output_image, loaded_model_setup]
|
221 |
).then(
|
222 |
fn = generate_image,
|