Spaces:
Runtime error
Runtime error
1lint
commited on
Commit
·
0d0a1c2
1
Parent(s):
34d5395
fix and revise app
Browse files- src/app.py +32 -28
- src/controlnet_pipe.py +1 -1
- src/lab.py +10 -2
- src/ui_assets/footer.MD +3 -0
- src/ui_assets/footer.html +0 -9
- src/ui_assets/header.MD +8 -0
- src/ui_assets/header.html +0 -23
- src/ui_functions.py +7 -8
src/app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
from multiprocessing import cpu_count
|
3 |
-
|
4 |
from src.ui_shared import (
|
5 |
model_ids,
|
6 |
scheduler_names,
|
@@ -13,10 +13,12 @@ from src.ui_functions import generate, run_training
|
|
13 |
|
14 |
default_img_size = 512
|
15 |
|
16 |
-
|
|
|
17 |
header = fp.read()
|
18 |
|
19 |
-
|
|
|
20 |
footer = fp.read()
|
21 |
|
22 |
|
@@ -25,13 +27,11 @@ theme = gr.themes.Soft(
|
|
25 |
neutral_hue="slate",
|
26 |
)
|
27 |
|
28 |
-
from gradio.themes.builder_app import css
|
29 |
-
|
30 |
with gr.Blocks(theme=theme) as demo:
|
31 |
|
32 |
-
gr.
|
33 |
|
34 |
-
with gr.Row():
|
35 |
with gr.Column(scale=70):
|
36 |
prompt = gr.Textbox(
|
37 |
label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
|
@@ -53,15 +53,17 @@ with gr.Blocks(theme=theme) as demo:
|
|
53 |
|
54 |
with gr.Column(scale=30):
|
55 |
model_name = gr.Dropdown(
|
56 |
-
label="Model", choices=model_ids, value=model_ids[0]
|
57 |
)
|
58 |
controlnet_name = gr.Dropdown(
|
59 |
-
label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0]
|
60 |
)
|
61 |
scheduler_name = gr.Dropdown(
|
62 |
-
label="Scheduler", choices=scheduler_names, value=default_scheduler
|
63 |
)
|
64 |
-
|
|
|
|
|
65 |
|
66 |
with gr.Row():
|
67 |
with gr.Column():
|
@@ -114,7 +116,7 @@ with gr.Blocks(theme=theme) as demo:
|
|
114 |
)
|
115 |
|
116 |
|
117 |
-
with gr.Tab("Train
|
118 |
with gr.Row():
|
119 |
train_batch_size = gr.Slider(
|
120 |
label="Training Batch Size",
|
@@ -133,8 +135,8 @@ with gr.Blocks(theme=theme) as demo:
|
|
133 |
)
|
134 |
|
135 |
with gr.Row():
|
136 |
-
|
137 |
-
label="Total training
|
138 |
)
|
139 |
train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
|
140 |
|
@@ -160,7 +162,7 @@ with gr.Blocks(theme=theme) as demo:
|
|
160 |
with gr.Row():
|
161 |
controlnet_weights_path = gr.Textbox(
|
162 |
label=f"Repo for initializing Controlnet Weights",
|
163 |
-
value="
|
164 |
)
|
165 |
output_dir = gr.Textbox(
|
166 |
label=f"Output directory for trained weights", value="./models"
|
@@ -195,7 +197,7 @@ with gr.Blocks(theme=theme) as demo:
|
|
195 |
# vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
|
196 |
# demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)
|
197 |
|
198 |
-
|
199 |
|
200 |
inputs = [
|
201 |
model_name,
|
@@ -228,7 +230,7 @@ with gr.Blocks(theme=theme) as demo:
|
|
228 |
train_batch_size,
|
229 |
train_whole_controlnet,
|
230 |
gradient_accumulation_steps,
|
231 |
-
|
232 |
train_learning_rate,
|
233 |
output_dir,
|
234 |
checkpointing_steps,
|
@@ -242,19 +244,21 @@ with gr.Blocks(theme=theme) as demo:
|
|
242 |
outputs=[training_status],
|
243 |
)
|
244 |
|
245 |
-
# from gradio.themes.
|
246 |
-
|
247 |
-
None,
|
248 |
-
None,
|
249 |
-
None,
|
250 |
_js="""() => {
|
251 |
if (document.querySelectorAll('.dark').length) {
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
)
|
|
|
|
|
258 |
|
259 |
if __name__ == "__main__":
|
260 |
-
demo.queue(concurrency_count=cpu_count()).launch()
|
|
|
1 |
import gradio as gr
|
2 |
from multiprocessing import cpu_count
|
3 |
+
from pathlib import Path
|
4 |
from src.ui_shared import (
|
5 |
model_ids,
|
6 |
scheduler_names,
|
|
|
13 |
|
14 |
default_img_size = 512
|
15 |
|
16 |
+
|
17 |
+
with open(f"{assets_directory}/header.MD") as fp:
|
18 |
header = fp.read()
|
19 |
|
20 |
+
|
21 |
+
with open(f"{assets_directory}/footer.MD") as fp:
|
22 |
footer = fp.read()
|
23 |
|
24 |
|
|
|
27 |
neutral_hue="slate",
|
28 |
)
|
29 |
|
|
|
|
|
30 |
with gr.Blocks(theme=theme) as demo:
|
31 |
|
32 |
+
header_component = gr.Markdown(header)
|
33 |
|
34 |
+
with gr.Row().style(equal_height=True):
|
35 |
with gr.Column(scale=70):
|
36 |
prompt = gr.Textbox(
|
37 |
label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
|
|
|
53 |
|
54 |
with gr.Column(scale=30):
|
55 |
model_name = gr.Dropdown(
|
56 |
+
label="Model", choices=model_ids, value=model_ids[0], allow_custom_value=True
|
57 |
)
|
58 |
controlnet_name = gr.Dropdown(
|
59 |
+
label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0], allow_custom_value=True
|
60 |
)
|
61 |
scheduler_name = gr.Dropdown(
|
62 |
+
label="Scheduler", choices=scheduler_names, value=default_scheduler, allow_custom_value=True
|
63 |
)
|
64 |
+
with gr.Row():
|
65 |
+
generate_button = gr.Button(value="Generate", variant="primary")
|
66 |
+
dark_mode_btn = gr.Button("Dark Mode", variant="secondary")
|
67 |
|
68 |
with gr.Row():
|
69 |
with gr.Column():
|
|
|
116 |
)
|
117 |
|
118 |
|
119 |
+
with gr.Tab("Train Anime ControlNet") as tab:
|
120 |
with gr.Row():
|
121 |
train_batch_size = gr.Slider(
|
122 |
label="Training Batch Size",
|
|
|
135 |
)
|
136 |
|
137 |
with gr.Row():
|
138 |
+
num_train_epochs = gr.Number(
|
139 |
+
label="Total training epochs", value=2
|
140 |
)
|
141 |
train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
|
142 |
|
|
|
162 |
with gr.Row():
|
163 |
controlnet_weights_path = gr.Textbox(
|
164 |
label=f"Repo for initializing Controlnet Weights",
|
165 |
+
value="lint/anime_control/anime_merge",
|
166 |
)
|
167 |
output_dir = gr.Textbox(
|
168 |
label=f"Output directory for trained weights", value="./models"
|
|
|
197 |
# vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
|
198 |
# demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)
|
199 |
|
200 |
+
footer_component = gr.Markdown(footer)
|
201 |
|
202 |
inputs = [
|
203 |
model_name,
|
|
|
230 |
train_batch_size,
|
231 |
train_whole_controlnet,
|
232 |
gradient_accumulation_steps,
|
233 |
+
num_train_epochs,
|
234 |
train_learning_rate,
|
235 |
output_dir,
|
236 |
checkpointing_steps,
|
|
|
244 |
outputs=[training_status],
|
245 |
)
|
246 |
|
247 |
+
# from gradio.themes.builder
|
248 |
+
toggle_dark_mode_args = dict(
|
249 |
+
fn=None,
|
250 |
+
inputs=None,
|
251 |
+
outputs=None,
|
252 |
_js="""() => {
|
253 |
if (document.querySelectorAll('.dark').length) {
|
254 |
+
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
|
255 |
+
} else {
|
256 |
+
document.querySelector('body').classList.add('dark');
|
257 |
+
}
|
258 |
+
}""",
|
259 |
)
|
260 |
+
demo.load(**toggle_dark_mode_args)
|
261 |
+
dark_mode_btn.click(**toggle_dark_mode_args)
|
262 |
|
263 |
if __name__ == "__main__":
|
264 |
+
demo.queue(concurrency_count=cpu_count()).launch(favicon_path=favicon_path)
|
src/controlnet_pipe.py
CHANGED
@@ -212,7 +212,7 @@ class ControlNetPipe(StableDiffusionControlNetPipeline):
|
|
212 |
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
213 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
214 |
|
215 |
-
if
|
216 |
controlnet_prompt_embeds = prompt_embeds
|
217 |
|
218 |
# 8. Denoising loop
|
|
|
212 |
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
213 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
214 |
|
215 |
+
if controlnet_prompt_embeds is None:
|
216 |
controlnet_prompt_embeds = prompt_embeds
|
217 |
|
218 |
# 8. Denoising loop
|
src/lab.py
CHANGED
@@ -372,9 +372,12 @@ class Lab(Accelerator):
|
|
372 |
str(save_dir / "unet"), safe_serialization=to_safetensors
|
373 |
)
|
374 |
|
375 |
-
def train(self, num_train_epochs=1000):
|
376 |
args = self.args
|
377 |
|
|
|
|
|
|
|
378 |
max_train_steps = (
|
379 |
num_train_epochs
|
380 |
* len(self.train_dataloader)
|
@@ -386,6 +389,7 @@ class Lab(Accelerator):
|
|
386 |
|
387 |
self.global_step = 0
|
388 |
|
|
|
389 |
# Only show the progress bar once on each machine.
|
390 |
progress_bar = tqdm(
|
391 |
range(max_train_steps),
|
@@ -396,11 +400,13 @@ class Lab(Accelerator):
|
|
396 |
try:
|
397 |
for epoch in range(num_train_epochs):
|
398 |
# run training loop
|
|
|
|
|
399 |
if self.controlnet:
|
400 |
self.controlnet.train()
|
401 |
else:
|
402 |
self.unet.train()
|
403 |
-
for batch in self.train_dataloader:
|
404 |
loss, encoder_hidden_states = self.compute_loss(batch)
|
405 |
|
406 |
loss /= args.gradient_accumulation_steps
|
@@ -416,6 +422,8 @@ class Lab(Accelerator):
|
|
416 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
417 |
if self.sync_gradients:
|
418 |
progress_bar.update(1)
|
|
|
|
|
419 |
self.global_step += 1
|
420 |
|
421 |
if self.is_main_process:
|
|
|
372 |
str(save_dir / "unet"), safe_serialization=to_safetensors
|
373 |
)
|
374 |
|
375 |
+
def train(self, num_train_epochs=1000, gr_progress = None):
|
376 |
args = self.args
|
377 |
|
378 |
+
if args.num_train_epochs:
|
379 |
+
num_train_epochs = args.num_train_epochs
|
380 |
+
|
381 |
max_train_steps = (
|
382 |
num_train_epochs
|
383 |
* len(self.train_dataloader)
|
|
|
389 |
|
390 |
self.global_step = 0
|
391 |
|
392 |
+
|
393 |
# Only show the progress bar once on each machine.
|
394 |
progress_bar = tqdm(
|
395 |
range(max_train_steps),
|
|
|
400 |
try:
|
401 |
for epoch in range(num_train_epochs):
|
402 |
# run training loop
|
403 |
+
if gr_progress is not None:
|
404 |
+
gr_progress(0, desc=f"Starting Epoch {epoch}")
|
405 |
if self.controlnet:
|
406 |
self.controlnet.train()
|
407 |
else:
|
408 |
self.unet.train()
|
409 |
+
for i, batch in enumerate(self.train_dataloader):
|
410 |
loss, encoder_hidden_states = self.compute_loss(batch)
|
411 |
|
412 |
loss /= args.gradient_accumulation_steps
|
|
|
422 |
# Checks if the accelerator has performed an optimization step behind the scenes
|
423 |
if self.sync_gradients:
|
424 |
progress_bar.update(1)
|
425 |
+
if gr_progress is not None:
|
426 |
+
gr_progress(float(i/len(self.train_dataloader)))
|
427 |
self.global_step += 1
|
428 |
|
429 |
if self.is_main_process:
|
src/ui_assets/footer.MD
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
### <p style="text-align: center;">Licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)</p>
|
src/ui_assets/footer.html
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
|
2 |
-
<!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
|
3 |
-
|
4 |
-
|
5 |
-
<div class="footer">
|
6 |
-
<p><h4>LICENSE</h4>
|
7 |
-
The default model is licensed with a <a href="https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL" style="text-decoration: underline;" target="_blank">CreativeML OpenRAIL++</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
|
8 |
-
</div>
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ui_assets/header.MD
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# <p style="text-align: center;">**Anime ControlNet Web UI**</p>
|
3 |
+
|
4 |
+
### <p style="text-align: center;">Try Anime Controlnet Web UI with free GPU at [yfu.one](https://yfu.one/full_ui/)</p>
|
5 |
+
|
6 |
+
### <p style="text-align: center;">[HuggingFace Models](https://huggingface.co/lint/anime_control) for downloading anime controlnet weights | [Github Repo](https://github.com/1lint/anime_controlnet) for training code</p>
|
7 |
+
|
8 |
+
|
src/ui_assets/header.html
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
|
2 |
-
<!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
|
3 |
-
|
4 |
-
<div style="text-align: center; margin: 0 auto;">
|
5 |
-
<div
|
6 |
-
style="
|
7 |
-
display: inline-flex;
|
8 |
-
align-items: center;
|
9 |
-
gap: 0.8rem;
|
10 |
-
font-size: 1.75rem;
|
11 |
-
"
|
12 |
-
>
|
13 |
-
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 32 32" style="enable-background:new 0 0 512 512;" xml:space="preserve" width="32" height="32"><path style="fill:#FCD577;" d="M29.545 29.791V2.21c-1.22 0 -2.21 -0.99 -2.21 -2.21H4.665c0 1.22 -0.99 2.21 -2.21 2.21v27.581c1.22 0 2.21 0.99 2.21 2.21H27.335C27.335 30.779 28.325 29.791 29.545 29.791z"/><path x="98.205" y="58.928" style="fill:#99B6C6;" width="315.577" height="394.144" d="M6.138 3.683H25.861V28.317H6.138V3.683z"/><path x="98.205" y="58.928" style="fill:#7BD4EF;" width="315.577" height="131.317" d="M6.138 3.683H25.861V11.89H6.138V3.683z"/><g><path style="fill:#7190A5;" d="M14.498 10.274c0 1.446 0.983 1.155 1.953 1.502l0.504 5.317c0 0 -5.599 0.989 -6.026 2.007l0.27 -2.526c0.924 -1.462 1.286 -4.864 1.419 -6.809l0.086 0.006C12.697 9.876 14.498 10.166 14.498 10.274z"/><path style="fill:#7190A5;" d="M21.96 17.647c0 0 -0.707 1.458 -1.716 1.903c0 0 -1.502 -0.827 -1.502 -0.827c-2.276 -1.557 -2.366 -8.3 -2.366 -8.3c0 -1.718 -0.185 -1.615 -1.429 -1.615c-1.167 0 -2.127 -0.606 -2.242 0.963l-0.086 -0.006c0.059 -0.859 0.074 -1.433 0.074 -1.433c0 -1.718 1.449 -3.11 3.237 -3.11s3.237 1.392 3.237 3.11C19.168 8.332 19.334 15.617 21.96 17.647z"/></g><path style="fill:#6C8793;" d="M12.248 24.739c1.538 0.711 3.256 1.591 3.922 2.258c-1.374 0.354 -2.704 0.798 -3.513 1.32h-2.156c-1.096 -0.606 -2.011 -1.472 -2.501 -2.702c-1.953 -4.907 2.905 -8.664 2.905 -8.664c0.001 -0.001 0.002 -0.002 0.003 -0.003c0.213 -0.214 0.523 -0.301 0.811 -0.21l0.02 0.006c-0.142 0.337 -0.03 0.71 0.517 1.108c1.264 0.919 3.091 1.131 4.416 1.143c-1.755 1.338 -3.42 3.333 -4.367 5.618L12.248 24.739z"/><path style="fill:#577484;" d="M16.17 26.997c-0.666 -0.666 -2.385 -1.548 -3.922 -2.258l0.059 -0.126c0.947 -2.284 2.612 -4.28 4.367 -5.618c0.001 0 0.001 0 0.001 0c0.688 -0.525 1.391 -0.948 2.068 -1.247c0.001 0 0.001 0 0.001 0c1.009 -0.446 1.964 -0.617 2.742 -0.44c0.61 0.138 1.109 0.492 1.439 1.095c1.752 3.205 0.601 9.913 0.601 9.913H12.657C13.466 27.796 14.796 27.352 16.17 26.997z"/><path style="fill:#F7DEB0;" d="M14.38 13.1c-0.971 -0.347 -1.687 -1.564 -1.687 -3.01c0 -0.107 0.004 -0.213 0.011 -0.318c0.116 -1.569 1.075 -2.792 2.242 -2.792c1.244 0 2.253 1.392 2.253 3.11c0 0 -0.735 6.103 1.542 7.66c-0.677 0.299 -1.38 0.722 -2.068 1.247c0 0 0 0 -0.001 0c-1.326 -0.012 -3.152 -0.223 -4.416 -1.143c-0.547 -0.398 -0.659 -0.771 -0.517 -1.108c0.426 -1.018 3.171 -1.697 3.171 -1.697L14.38 13.1z"/><path style="fill:#E5CA9E;" d="M14.38 13.1c0 0 1.019 0.216 1.544 -0.309c0 0 -0.401 1.04 -1.346 1.04"/><g><path style="fill:#EAC36E;" points="437.361,0 413.79,58.926 472.717,35.356 " d="M27.335 0L25.862 3.683L29.545 2.21"/><path style="fill:#EAC36E;" points="437.361,512 413.79,453.074 472.717,476.644 " d="M27.335 32L25.862 28.317L29.545 29.791"/><path style="fill:#EAC36E;" points="74.639,512 98.21,453.074 39.283,476.644 " d="M4.665 32L6.138 28.317L2.455 29.791"/><path style="fill:#EAC36E;" points="39.283,35.356 98.21,58.926 74.639,0 " d="M2.455 2.21L6.138 3.683L4.665 0"/><path style="fill:#EAC36E;" d="M26.425 28.881H5.574V3.119h20.851v25.761H26.425zM6.702 27.754h18.597V4.246H6.702V27.754z"/></g><g><path style="fill:#486572;" d="M12.758 21.613c-0.659 0.767 -1.245 1.613 -1.722 2.531l0.486 0.202C11.82 23.401 12.241 22.483 12.758 21.613z"/><path style="fill:#486572;" d="M21.541 25.576l-0.37 0.068c-0.553 0.101 -1.097 0.212 -1.641 0.331l-0.071 -0.201l-0.059 -0.167c-0.019 -0.056 -0.035 -0.112 -0.052 -0.169l-0.104 -0.338l-0.088 -0.342c-0.112 -0.457 -0.197 -0.922 -0.235 -1.393c-0.035 -0.47 -0.032 -0.947 0.042 -1.417c0.072 -0.47 0.205 -0.935 0.422 -1.369c-0.272 0.402 -0.469 0.856 -0.606 1.329c-0.138 0.473 -0.207 0.967 -0.234 1.462c-0.024 0.496 0.002 0.993 0.057 1.487l0.046 0.37l0.063 0.367c0.011 0.061 0.02 0.123 0.033 0.184l0.039 0.182l0.037 0.174c-0.677 0.157 -1.351 0.327 -2.019 0.514c-0.131 0.037 -0.262 0.075 -0.392 0.114l0.004 -0.004c-0.117 -0.095 -0.232 -0.197 -0.35 -0.275c-0.059 -0.041 -0.117 -0.084 -0.177 -0.122l-0.179 -0.112c-0.239 -0.147 -0.482 -0.279 -0.727 -0.406c-0.489 -0.252 -0.985 -0.479 -1.484 -0.697c-0.998 -0.433 -2.01 -0.825 -3.026 -1.196c0.973 0.475 1.937 0.969 2.876 1.499c0.469 0.266 0.932 0.539 1.379 0.832c0.223 0.146 0.442 0.297 0.648 0.456l0.154 0.119c0.05 0.041 0.097 0.083 0.145 0.124c0.002 0.002 0.004 0.003 0.005 0.005c-0.339 0.109 -0.675 0.224 -1.009 0.349c-0.349 0.132 -0.696 0.273 -1.034 0.431c-0.338 0.159 -0.668 0.337 -0.973 0.549c0.322 -0.186 0.662 -0.334 1.01 -0.463c0.347 -0.129 0.701 -0.239 1.056 -0.34c0.394 -0.111 0.79 -0.208 1.19 -0.297c0.006 0.006 0.013 0.013 0.019 0.019l0.03 -0.03c0.306 -0.068 0.614 -0.132 0.922 -0.192c0.727 -0.14 1.457 -0.258 2.189 -0.362c0.731 -0.103 1.469 -0.195 2.197 -0.265l0.374 -0.036L21.541 25.576z"/></g></svg>
|
14 |
-
|
15 |
-
<h1 style="font-weight: 1000; margin-bottom: 8px;margin-top:8px">
|
16 |
-
<a href="https://github.com/1lint/style_controlnet">
|
17 |
-
Style ControlNet Web UI
|
18 |
-
</a>
|
19 |
-
</h1>
|
20 |
-
</div>
|
21 |
-
<p> Use the ControlNet architecture to control Stable Diffusion image generation style</p>
|
22 |
-
</div>
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ui_functions.py
CHANGED
@@ -135,14 +135,14 @@ def generate(
|
|
135 |
controlnet_prompt=None,
|
136 |
controlnet_negative_prompt=None,
|
137 |
controlnet_cond_scale=1.0,
|
138 |
-
progress=gr.Progress(
|
139 |
):
|
140 |
|
141 |
if seed == -1:
|
142 |
seed = random.randint(0, 2147483647)
|
143 |
|
144 |
if guidance_image:
|
145 |
-
|
146 |
else:
|
147 |
guidance_image = torch.zeros(n_images, 3, height, width)
|
148 |
|
@@ -199,13 +199,13 @@ def run_training(
|
|
199 |
train_batch_size,
|
200 |
train_whole_controlnet,
|
201 |
gradient_accumulation_steps,
|
202 |
-
|
203 |
train_learning_rate,
|
204 |
output_dir,
|
205 |
checkpointing_steps,
|
206 |
image_logging_steps,
|
207 |
save_whole_pipeline,
|
208 |
-
progress=gr.Progress(
|
209 |
):
|
210 |
global pipe
|
211 |
|
@@ -240,14 +240,13 @@ def run_training(
|
|
240 |
valid_data_dir=valid_data_dir,
|
241 |
resolution=512,
|
242 |
from_hf_hub = train_data_dir == "lint/anybooru",
|
243 |
-
controlnet_hint_key=
|
244 |
|
245 |
# training args
|
246 |
# options are ["zero convolutions", "input hint blocks"], trains whole controlnet by default
|
247 |
training_stage="" if train_whole_controlnet else "zero convolutions",
|
248 |
learning_rate=float(train_learning_rate),
|
249 |
-
num_train_epochs=
|
250 |
-
max_train_steps=int(max_train_steps),
|
251 |
seed=3434554,
|
252 |
max_grad_norm=1.0,
|
253 |
gradient_accumulation_steps=int(gradient_accumulation_steps),
|
@@ -271,7 +270,7 @@ def run_training(
|
|
271 |
|
272 |
try:
|
273 |
lab = Lab(training_args, pipe)
|
274 |
-
lab.train(training_args.num_train_epochs)
|
275 |
except Exception as e:
|
276 |
raise gr.Error(e)
|
277 |
|
|
|
135 |
controlnet_prompt=None,
|
136 |
controlnet_negative_prompt=None,
|
137 |
controlnet_cond_scale=1.0,
|
138 |
+
progress=gr.Progress(),
|
139 |
):
|
140 |
|
141 |
if seed == -1:
|
142 |
seed = random.randint(0, 2147483647)
|
143 |
|
144 |
if guidance_image:
|
145 |
+
guidance_image = extract_canny(guidance_image)
|
146 |
else:
|
147 |
guidance_image = torch.zeros(n_images, 3, height, width)
|
148 |
|
|
|
199 |
train_batch_size,
|
200 |
train_whole_controlnet,
|
201 |
gradient_accumulation_steps,
|
202 |
+
num_train_epochs,
|
203 |
train_learning_rate,
|
204 |
output_dir,
|
205 |
checkpointing_steps,
|
206 |
image_logging_steps,
|
207 |
save_whole_pipeline,
|
208 |
+
progress=gr.Progress(),
|
209 |
):
|
210 |
global pipe
|
211 |
|
|
|
240 |
valid_data_dir=valid_data_dir,
|
241 |
resolution=512,
|
242 |
from_hf_hub = train_data_dir == "lint/anybooru",
|
243 |
+
controlnet_hint_key="canny",
|
244 |
|
245 |
# training args
|
246 |
# options are ["zero convolutions", "input hint blocks"], trains whole controlnet by default
|
247 |
training_stage="" if train_whole_controlnet else "zero convolutions",
|
248 |
learning_rate=float(train_learning_rate),
|
249 |
+
num_train_epochs=int(num_train_epochs),
|
|
|
250 |
seed=3434554,
|
251 |
max_grad_norm=1.0,
|
252 |
gradient_accumulation_steps=int(gradient_accumulation_steps),
|
|
|
270 |
|
271 |
try:
|
272 |
lab = Lab(training_args, pipe)
|
273 |
+
lab.train(training_args.num_train_epochs, gr_progress=progress)
|
274 |
except Exception as e:
|
275 |
raise gr.Error(e)
|
276 |
|