1lint committed on
Commit
0d0a1c2
·
1 Parent(s): 34d5395

fix and revise app

Browse files
src/app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  from multiprocessing import cpu_count
3
-
4
  from src.ui_shared import (
5
  model_ids,
6
  scheduler_names,
@@ -13,10 +13,12 @@ from src.ui_functions import generate, run_training
13
 
14
  default_img_size = 512
15
 
16
- with open(f"{assets_directory}/header.html") as fp:
 
17
  header = fp.read()
18
 
19
- with open(f"{assets_directory}/footer.html") as fp:
 
20
  footer = fp.read()
21
 
22
 
@@ -25,13 +27,11 @@ theme = gr.themes.Soft(
25
  neutral_hue="slate",
26
  )
27
 
28
- from gradio.themes.builder_app import css
29
-
30
  with gr.Blocks(theme=theme) as demo:
31
 
32
- gr.HTML(header)
33
 
34
- with gr.Row():
35
  with gr.Column(scale=70):
36
  prompt = gr.Textbox(
37
  label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
@@ -53,15 +53,17 @@ with gr.Blocks(theme=theme) as demo:
53
 
54
  with gr.Column(scale=30):
55
  model_name = gr.Dropdown(
56
- label="Model", choices=model_ids, value=model_ids[0]
57
  )
58
  controlnet_name = gr.Dropdown(
59
- label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0]
60
  )
61
  scheduler_name = gr.Dropdown(
62
- label="Scheduler", choices=scheduler_names, value=default_scheduler
63
  )
64
- generate_button = gr.Button(value="Generate", variant="primary")
 
 
65
 
66
  with gr.Row():
67
  with gr.Column():
@@ -114,7 +116,7 @@ with gr.Blocks(theme=theme) as demo:
114
  )
115
 
116
 
117
- with gr.Tab("Train Style ControlNet") as tab:
118
  with gr.Row():
119
  train_batch_size = gr.Slider(
120
  label="Training Batch Size",
@@ -133,8 +135,8 @@ with gr.Blocks(theme=theme) as demo:
133
  )
134
 
135
  with gr.Row():
136
- max_train_steps = gr.Number(
137
- label="Total training steps", value=16000
138
  )
139
  train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
140
 
@@ -160,7 +162,7 @@ with gr.Blocks(theme=theme) as demo:
160
  with gr.Row():
161
  controlnet_weights_path = gr.Textbox(
162
  label=f"Repo for initializing Controlnet Weights",
163
- value="andite/anything-v4.0/unet",
164
  )
165
  output_dir = gr.Textbox(
166
  label=f"Output directory for trained weights", value="./models"
@@ -195,7 +197,7 @@ with gr.Blocks(theme=theme) as demo:
195
  # vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
196
  # demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)
197
 
198
- # gr.HTML(footer)
199
 
200
  inputs = [
201
  model_name,
@@ -228,7 +230,7 @@ with gr.Blocks(theme=theme) as demo:
228
  train_batch_size,
229
  train_whole_controlnet,
230
  gradient_accumulation_steps,
231
- max_train_steps,
232
  train_learning_rate,
233
  output_dir,
234
  checkpointing_steps,
@@ -242,19 +244,21 @@ with gr.Blocks(theme=theme) as demo:
242
  outputs=[training_status],
243
  )
244
 
245
- # from gradio.themes.builder_app
246
- demo.load(
247
- None,
248
- None,
249
- None,
250
  _js="""() => {
251
  if (document.querySelectorAll('.dark').length) {
252
- document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
253
- } else {
254
- document.querySelector('body').classList.add('dark');
255
- }
256
- }""",
257
  )
 
 
258
 
259
  if __name__ == "__main__":
260
- demo.queue(concurrency_count=cpu_count()).launch()
 
1
  import gradio as gr
2
  from multiprocessing import cpu_count
3
+ from pathlib import Path
4
  from src.ui_shared import (
5
  model_ids,
6
  scheduler_names,
 
13
 
14
  default_img_size = 512
15
 
16
+
17
+ with open(f"{assets_directory}/header.MD") as fp:
18
  header = fp.read()
19
 
20
+
21
+ with open(f"{assets_directory}/footer.MD") as fp:
22
  footer = fp.read()
23
 
24
 
 
27
  neutral_hue="slate",
28
  )
29
 
 
 
30
  with gr.Blocks(theme=theme) as demo:
31
 
32
+ header_component = gr.Markdown(header)
33
 
34
+ with gr.Row().style(equal_height=True):
35
  with gr.Column(scale=70):
36
  prompt = gr.Textbox(
37
  label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
 
53
 
54
  with gr.Column(scale=30):
55
  model_name = gr.Dropdown(
56
+ label="Model", choices=model_ids, value=model_ids[0], allow_custom_value=True
57
  )
58
  controlnet_name = gr.Dropdown(
59
+ label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0], allow_custom_value=True
60
  )
61
  scheduler_name = gr.Dropdown(
62
+ label="Scheduler", choices=scheduler_names, value=default_scheduler, allow_custom_value=True
63
  )
64
+ with gr.Row():
65
+ generate_button = gr.Button(value="Generate", variant="primary")
66
+ dark_mode_btn = gr.Button("Dark Mode", variant="secondary")
67
 
68
  with gr.Row():
69
  with gr.Column():
 
116
  )
117
 
118
 
119
+ with gr.Tab("Train Anime ControlNet") as tab:
120
  with gr.Row():
121
  train_batch_size = gr.Slider(
122
  label="Training Batch Size",
 
135
  )
136
 
137
  with gr.Row():
138
+ num_train_epochs = gr.Number(
139
+ label="Total training epochs", value=2
140
  )
141
  train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)
142
 
 
162
  with gr.Row():
163
  controlnet_weights_path = gr.Textbox(
164
  label=f"Repo for initializing Controlnet Weights",
165
+ value="lint/anime_control/anime_merge",
166
  )
167
  output_dir = gr.Textbox(
168
  label=f"Output directory for trained weights", value="./models"
 
197
  # vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
198
  # demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)
199
 
200
+ footer_component = gr.Markdown(footer)
201
 
202
  inputs = [
203
  model_name,
 
230
  train_batch_size,
231
  train_whole_controlnet,
232
  gradient_accumulation_steps,
233
+ num_train_epochs,
234
  train_learning_rate,
235
  output_dir,
236
  checkpointing_steps,
 
244
  outputs=[training_status],
245
  )
246
 
247
+ # from gradio.themes.builder
248
+ toggle_dark_mode_args = dict(
249
+ fn=None,
250
+ inputs=None,
251
+ outputs=None,
252
  _js="""() => {
253
  if (document.querySelectorAll('.dark').length) {
254
+ document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
255
+ } else {
256
+ document.querySelector('body').classList.add('dark');
257
+ }
258
+ }""",
259
  )
260
+ demo.load(**toggle_dark_mode_args)
261
+ dark_mode_btn.click(**toggle_dark_mode_args)
262
 
263
  if __name__ == "__main__":
264
+ demo.queue(concurrency_count=cpu_count()).launch(favicon_path=favicon_path)
src/controlnet_pipe.py CHANGED
@@ -212,7 +212,7 @@ class ControlNetPipe(StableDiffusionControlNetPipeline):
212
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
213
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
214
 
215
- if not controlnet_prompt_embeds:
216
  controlnet_prompt_embeds = prompt_embeds
217
 
218
  # 8. Denoising loop
 
212
  # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
213
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
214
 
215
+ if controlnet_prompt_embeds is None:
216
  controlnet_prompt_embeds = prompt_embeds
217
 
218
  # 8. Denoising loop
src/lab.py CHANGED
@@ -372,9 +372,12 @@ class Lab(Accelerator):
372
  str(save_dir / "unet"), safe_serialization=to_safetensors
373
  )
374
 
375
- def train(self, num_train_epochs=1000):
376
  args = self.args
377
 
 
 
 
378
  max_train_steps = (
379
  num_train_epochs
380
  * len(self.train_dataloader)
@@ -386,6 +389,7 @@ class Lab(Accelerator):
386
 
387
  self.global_step = 0
388
 
 
389
  # Only show the progress bar once on each machine.
390
  progress_bar = tqdm(
391
  range(max_train_steps),
@@ -396,11 +400,13 @@ class Lab(Accelerator):
396
  try:
397
  for epoch in range(num_train_epochs):
398
  # run training loop
 
 
399
  if self.controlnet:
400
  self.controlnet.train()
401
  else:
402
  self.unet.train()
403
- for batch in self.train_dataloader:
404
  loss, encoder_hidden_states = self.compute_loss(batch)
405
 
406
  loss /= args.gradient_accumulation_steps
@@ -416,6 +422,8 @@ class Lab(Accelerator):
416
  # Checks if the accelerator has performed an optimization step behind the scenes
417
  if self.sync_gradients:
418
  progress_bar.update(1)
 
 
419
  self.global_step += 1
420
 
421
  if self.is_main_process:
 
372
  str(save_dir / "unet"), safe_serialization=to_safetensors
373
  )
374
 
375
+ def train(self, num_train_epochs=1000, gr_progress = None):
376
  args = self.args
377
 
378
+ if args.num_train_epochs:
379
+ num_train_epochs = args.num_train_epochs
380
+
381
  max_train_steps = (
382
  num_train_epochs
383
  * len(self.train_dataloader)
 
389
 
390
  self.global_step = 0
391
 
392
+
393
  # Only show the progress bar once on each machine.
394
  progress_bar = tqdm(
395
  range(max_train_steps),
 
400
  try:
401
  for epoch in range(num_train_epochs):
402
  # run training loop
403
+ if gr_progress is not None:
404
+ gr_progress(0, desc=f"Starting Epoch {epoch}")
405
  if self.controlnet:
406
  self.controlnet.train()
407
  else:
408
  self.unet.train()
409
+ for i, batch in enumerate(self.train_dataloader):
410
  loss, encoder_hidden_states = self.compute_loss(batch)
411
 
412
  loss /= args.gradient_accumulation_steps
 
422
  # Checks if the accelerator has performed an optimization step behind the scenes
423
  if self.sync_gradients:
424
  progress_bar.update(1)
425
+ if gr_progress is not None:
426
+ gr_progress(float(i/len(self.train_dataloader)))
427
  self.global_step += 1
428
 
429
  if self.is_main_process:
src/ui_assets/footer.MD ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+
2
+
3
+ ### <p style="text-align: center;">Licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)</p>
src/ui_assets/footer.html DELETED
@@ -1,9 +0,0 @@
1
-
2
- <!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
3
-
4
-
5
- <div class="footer">
6
- <p><h4>LICENSE</h4>
7
- The default model is licensed with a <a href="https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL" style="text-decoration: underline;" target="_blank">CreativeML OpenRAIL++</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
8
- </div>
9
-
 
 
 
 
 
 
 
 
 
 
src/ui_assets/header.MD ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+ # <p style="text-align: center;">**Anime ControlNet Web UI**</p>
3
+
4
+ ### <p style="text-align: center;">Try Anime Controlnet Web UI with free GPU at [yfu.one](https://yfu.one/full_ui/)</p>
5
+
6
+ ### <p style="text-align: center;">[HuggingFace Models](https://huggingface.co/lint/anime_control) for downloading anime controlnet weights | [Github Repo](https://github.com/1lint/anime_controlnet) for training code</p>
7
+
8
+
src/ui_assets/header.html DELETED
@@ -1,23 +0,0 @@
1
-
2
- <!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->
3
-
4
- <div style="text-align: center; margin: 0 auto;">
5
- <div
6
- style="
7
- display: inline-flex;
8
- align-items: center;
9
- gap: 0.8rem;
10
- font-size: 1.75rem;
11
- "
12
- >
13
- <svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 32 32" style="enable-background:new 0 0 512 512;" xml:space="preserve" width="32" height="32"><path style="fill:#FCD577;" d="M29.545 29.791V2.21c-1.22 0 -2.21 -0.99 -2.21 -2.21H4.665c0 1.22 -0.99 2.21 -2.21 2.21v27.581c1.22 0 2.21 0.99 2.21 2.21H27.335C27.335 30.779 28.325 29.791 29.545 29.791z"/><path x="98.205" y="58.928" style="fill:#99B6C6;" width="315.577" height="394.144" d="M6.138 3.683H25.861V28.317H6.138V3.683z"/><path x="98.205" y="58.928" style="fill:#7BD4EF;" width="315.577" height="131.317" d="M6.138 3.683H25.861V11.89H6.138V3.683z"/><g><path style="fill:#7190A5;" d="M14.498 10.274c0 1.446 0.983 1.155 1.953 1.502l0.504 5.317c0 0 -5.599 0.989 -6.026 2.007l0.27 -2.526c0.924 -1.462 1.286 -4.864 1.419 -6.809l0.086 0.006C12.697 9.876 14.498 10.166 14.498 10.274z"/><path style="fill:#7190A5;" d="M21.96 17.647c0 0 -0.707 1.458 -1.716 1.903c0 0 -1.502 -0.827 -1.502 -0.827c-2.276 -1.557 -2.366 -8.3 -2.366 -8.3c0 -1.718 -0.185 -1.615 -1.429 -1.615c-1.167 0 -2.127 -0.606 -2.242 0.963l-0.086 -0.006c0.059 -0.859 0.074 -1.433 0.074 -1.433c0 -1.718 1.449 -3.11 3.237 -3.11s3.237 1.392 3.237 3.11C19.168 8.332 19.334 15.617 21.96 17.647z"/></g><path style="fill:#6C8793;" d="M12.248 24.739c1.538 0.711 3.256 1.591 3.922 2.258c-1.374 0.354 -2.704 0.798 -3.513 1.32h-2.156c-1.096 -0.606 -2.011 -1.472 -2.501 -2.702c-1.953 -4.907 2.905 -8.664 2.905 -8.664c0.001 -0.001 0.002 -0.002 0.003 -0.003c0.213 -0.214 0.523 -0.301 0.811 -0.21l0.02 0.006c-0.142 0.337 -0.03 0.71 0.517 1.108c1.264 0.919 3.091 1.131 4.416 1.143c-1.755 1.338 -3.42 3.333 -4.367 5.618L12.248 24.739z"/><path style="fill:#577484;" d="M16.17 26.997c-0.666 -0.666 -2.385 -1.548 -3.922 -2.258l0.059 -0.126c0.947 -2.284 2.612 -4.28 4.367 -5.618c0.001 0 0.001 0 0.001 0c0.688 -0.525 1.391 -0.948 2.068 -1.247c0.001 0 0.001 0 0.001 0c1.009 -0.446 1.964 -0.617 2.742 -0.44c0.61 0.138 
1.109 0.492 1.439 1.095c1.752 3.205 0.601 9.913 0.601 9.913H12.657C13.466 27.796 14.796 27.352 16.17 26.997z"/><path style="fill:#F7DEB0;" d="M14.38 13.1c-0.971 -0.347 -1.687 -1.564 -1.687 -3.01c0 -0.107 0.004 -0.213 0.011 -0.318c0.116 -1.569 1.075 -2.792 2.242 -2.792c1.244 0 2.253 1.392 2.253 3.11c0 0 -0.735 6.103 1.542 7.66c-0.677 0.299 -1.38 0.722 -2.068 1.247c0 0 0 0 -0.001 0c-1.326 -0.012 -3.152 -0.223 -4.416 -1.143c-0.547 -0.398 -0.659 -0.771 -0.517 -1.108c0.426 -1.018 3.171 -1.697 3.171 -1.697L14.38 13.1z"/><path style="fill:#E5CA9E;" d="M14.38 13.1c0 0 1.019 0.216 1.544 -0.309c0 0 -0.401 1.04 -1.346 1.04"/><g><path style="fill:#EAC36E;" points="437.361,0 413.79,58.926 472.717,35.356 " d="M27.335 0L25.862 3.683L29.545 2.21"/><path style="fill:#EAC36E;" points="437.361,512 413.79,453.074 472.717,476.644 " d="M27.335 32L25.862 28.317L29.545 29.791"/><path style="fill:#EAC36E;" points="74.639,512 98.21,453.074 39.283,476.644 " d="M4.665 32L6.138 28.317L2.455 29.791"/><path style="fill:#EAC36E;" points="39.283,35.356 98.21,58.926 74.639,0 " d="M2.455 2.21L6.138 3.683L4.665 0"/><path style="fill:#EAC36E;" d="M26.425 28.881H5.574V3.119h20.851v25.761H26.425zM6.702 27.754h18.597V4.246H6.702V27.754z"/></g><g><path style="fill:#486572;" d="M12.758 21.613c-0.659 0.767 -1.245 1.613 -1.722 2.531l0.486 0.202C11.82 23.401 12.241 22.483 12.758 21.613z"/><path style="fill:#486572;" d="M21.541 25.576l-0.37 0.068c-0.553 0.101 -1.097 0.212 -1.641 0.331l-0.071 -0.201l-0.059 -0.167c-0.019 -0.056 -0.035 -0.112 -0.052 -0.169l-0.104 -0.338l-0.088 -0.342c-0.112 -0.457 -0.197 -0.922 -0.235 -1.393c-0.035 -0.47 -0.032 -0.947 0.042 -1.417c0.072 -0.47 0.205 -0.935 0.422 -1.369c-0.272 0.402 -0.469 0.856 -0.606 1.329c-0.138 0.473 -0.207 0.967 -0.234 1.462c-0.024 0.496 0.002 0.993 0.057 1.487l0.046 0.37l0.063 0.367c0.011 0.061 0.02 0.123 0.033 0.184l0.039 0.182l0.037 0.174c-0.677 0.157 -1.351 0.327 -2.019 0.514c-0.131 0.037 -0.262 0.075 -0.392 0.114l0.004 -0.004c-0.117 -0.095 -0.232 -0.197 
-0.35 -0.275c-0.059 -0.041 -0.117 -0.084 -0.177 -0.122l-0.179 -0.112c-0.239 -0.147 -0.482 -0.279 -0.727 -0.406c-0.489 -0.252 -0.985 -0.479 -1.484 -0.697c-0.998 -0.433 -2.01 -0.825 -3.026 -1.196c0.973 0.475 1.937 0.969 2.876 1.499c0.469 0.266 0.932 0.539 1.379 0.832c0.223 0.146 0.442 0.297 0.648 0.456l0.154 0.119c0.05 0.041 0.097 0.083 0.145 0.124c0.002 0.002 0.004 0.003 0.005 0.005c-0.339 0.109 -0.675 0.224 -1.009 0.349c-0.349 0.132 -0.696 0.273 -1.034 0.431c-0.338 0.159 -0.668 0.337 -0.973 0.549c0.322 -0.186 0.662 -0.334 1.01 -0.463c0.347 -0.129 0.701 -0.239 1.056 -0.34c0.394 -0.111 0.79 -0.208 1.19 -0.297c0.006 0.006 0.013 0.013 0.019 0.019l0.03 -0.03c0.306 -0.068 0.614 -0.132 0.922 -0.192c0.727 -0.14 1.457 -0.258 2.189 -0.362c0.731 -0.103 1.469 -0.195 2.197 -0.265l0.374 -0.036L21.541 25.576z"/></g></svg>
14
-
15
- <h1 style="font-weight: 1000; margin-bottom: 8px;margin-top:8px">
16
- <a href="https://github.com/1lint/style_controlnet">
17
- Style ControlNet Web UI
18
- </a>
19
- </h1>
20
- </div>
21
- <p> Use the ControlNet architecture to control Stable Diffusion image generation style</p>
22
- </div>
23
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/ui_functions.py CHANGED
@@ -135,14 +135,14 @@ def generate(
135
  controlnet_prompt=None,
136
  controlnet_negative_prompt=None,
137
  controlnet_cond_scale=1.0,
138
- progress=gr.Progress(track_tqdm=True),
139
  ):
140
 
141
  if seed == -1:
142
  seed = random.randint(0, 2147483647)
143
 
144
  if guidance_image:
145
- guiadnce_image = extract_canny(guidance_image)
146
  else:
147
  guidance_image = torch.zeros(n_images, 3, height, width)
148
 
@@ -199,13 +199,13 @@ def run_training(
199
  train_batch_size,
200
  train_whole_controlnet,
201
  gradient_accumulation_steps,
202
- max_train_steps,
203
  train_learning_rate,
204
  output_dir,
205
  checkpointing_steps,
206
  image_logging_steps,
207
  save_whole_pipeline,
208
- progress=gr.Progress(track_tqdm=True),
209
  ):
210
  global pipe
211
 
@@ -240,14 +240,13 @@ def run_training(
240
  valid_data_dir=valid_data_dir,
241
  resolution=512,
242
  from_hf_hub = train_data_dir == "lint/anybooru",
243
- controlnet_hint_key=None,
244
 
245
  # training args
246
  # options are ["zero convolutions", "input hint blocks"], trains whole controlnet by default
247
  training_stage="" if train_whole_controlnet else "zero convolutions",
248
  learning_rate=float(train_learning_rate),
249
- num_train_epochs=1000,
250
- max_train_steps=int(max_train_steps),
251
  seed=3434554,
252
  max_grad_norm=1.0,
253
  gradient_accumulation_steps=int(gradient_accumulation_steps),
@@ -271,7 +270,7 @@ def run_training(
271
 
272
  try:
273
  lab = Lab(training_args, pipe)
274
- lab.train(training_args.num_train_epochs)
275
  except Exception as e:
276
  raise gr.Error(e)
277
 
 
135
  controlnet_prompt=None,
136
  controlnet_negative_prompt=None,
137
  controlnet_cond_scale=1.0,
138
+ progress=gr.Progress(),
139
  ):
140
 
141
  if seed == -1:
142
  seed = random.randint(0, 2147483647)
143
 
144
  if guidance_image:
145
+ guidance_image = extract_canny(guidance_image)
146
  else:
147
  guidance_image = torch.zeros(n_images, 3, height, width)
148
 
 
199
  train_batch_size,
200
  train_whole_controlnet,
201
  gradient_accumulation_steps,
202
+ num_train_epochs,
203
  train_learning_rate,
204
  output_dir,
205
  checkpointing_steps,
206
  image_logging_steps,
207
  save_whole_pipeline,
208
+ progress=gr.Progress(),
209
  ):
210
  global pipe
211
 
 
240
  valid_data_dir=valid_data_dir,
241
  resolution=512,
242
  from_hf_hub = train_data_dir == "lint/anybooru",
243
+ controlnet_hint_key="canny",
244
 
245
  # training args
246
  # options are ["zero convolutions", "input hint blocks"], trains whole controlnet by default
247
  training_stage="" if train_whole_controlnet else "zero convolutions",
248
  learning_rate=float(train_learning_rate),
249
+ num_train_epochs=int(num_train_epochs),
 
250
  seed=3434554,
251
  max_grad_norm=1.0,
252
  gradient_accumulation_steps=int(gradient_accumulation_steps),
 
270
 
271
  try:
272
  lab = Lab(training_args, pipe)
273
+ lab.train(training_args.num_train_epochs, gr_progress=progress)
274
  except Exception as e:
275
  raise gr.Error(e)
276