multimodalart HF staff commited on
Commit
7967a47
1 Parent(s): cb9bf15

Remove GPU attribution if CUDA error

Browse files
Files changed (1) hide show
  1. app.py +96 -72
app.py CHANGED
@@ -35,15 +35,8 @@ else:
35
  is_shared_ui = False
36
  is_gpu_associated = torch.cuda.is_available()
37
 
38
- css = '''
39
- .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
40
- .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
41
- #component-4, #component-3, #component-10{min-height: 0}
42
- .duplicate-button img{margin: 0}
43
- '''
44
- maximum_concepts = 3
45
 
46
- #Pre download the files
47
  if(is_gpu_associated):
48
  model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
49
  model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-1", ignore_patterns=["*.ckpt", "*.safetensors"])
@@ -51,8 +44,25 @@ if(is_gpu_associated):
51
  safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
52
  model_to_load = model_v1
53
 
54
- #with zipfile.ZipFile("mix.zip", 'r') as zip_ref:
55
- # zip_ref.extractall(".")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def swap_text(option, base):
58
  resize_width = 768 if base == "v2-1-768" else 512
@@ -60,7 +70,7 @@ def swap_text(option, base):
60
  if(option == "object"):
61
  instance_prompt_example = "cttoy"
62
  freeze_for = 30
63
- return [f"You are going to train `object`(s), upload 5-10 images of each object you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file/cat-toy.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, gr.update(visible=False)]
64
  elif(option == "person"):
65
  instance_prompt_example = "julcto"
66
  freeze_for = 70
@@ -70,27 +80,17 @@ def swap_text(option, base):
70
  prior_preservation_box_update = gr.update(visible=show_prior_preservation)
71
  else:
72
  prior_preservation_box_update = gr.update(visible=show_prior_preservation, value=False)
73
- return [f"You are going to train a `person`(s), upload 10-20 images of each person you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file/person.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, prior_preservation_box_update]
74
  elif(option == "style"):
75
  instance_prompt_example = "trsldamrl"
76
  freeze_for = 10
77
- return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}", freeze_for, gr.update(visible=False)]
78
-
79
- def swap_base_model(selected_model):
80
- if(is_gpu_associated):
81
- global model_to_load
82
- if(selected_model == "v1-5"):
83
- model_to_load = model_v1
84
- elif(selected_model == "v2-1-768"):
85
- model_to_load = model_v2
86
- else:
87
- model_to_load = model_v2_512
88
 
89
  def count_files(*inputs):
90
  file_counter = 0
91
  concept_counter = 0
92
  for i, input in enumerate(inputs):
93
- if(i < maximum_concepts-1):
94
  files = inputs[i]
95
  if(files):
96
  concept_counter+=1
@@ -133,6 +133,9 @@ def update_steps(*files_list):
133
  file_counter+=len(files)
134
  return(gr.update(value=file_counter*200))
135
 
 
 
 
136
  def pad_image(image):
137
  w, h = image.size
138
  if w == h:
@@ -163,7 +166,34 @@ def validate_model_upload(hf_token, model_name):
163
  if(model_name == ""):
164
  raise gr.Error("Please fill in your model's name")
165
 
166
- def train(*inputs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  if is_shared_ui:
168
  raise gr.Error("This Space only works in duplicated instances")
169
  if not is_gpu_associated:
@@ -171,6 +201,9 @@ def train(*inputs):
171
  hf_token = inputs[-5]
172
  model_name = inputs[-7]
173
  if(is_spaces):
 
 
 
174
  remove_attribution_after = inputs[-6]
175
  else:
176
  remove_attribution_after = False
@@ -191,7 +224,6 @@ def train(*inputs):
191
  if os.path.exists("model.ckpt"): os.remove("model.ckpt")
192
  if os.path.exists("hastrained.success"): os.remove("hastrained.success")
193
  file_counter = 0
194
- which_model = inputs[-10]
195
  resolution = 512 if which_model != "v2-1-768" else 768
196
  for i, input in enumerate(inputs):
197
  if(i < maximum_concepts-1):
@@ -261,38 +293,22 @@ def train(*inputs):
261
  print("Starting single training...")
262
  lock_file = open("intraining.lock", "w")
263
  lock_file.close()
264
- run_training(args_general)
265
- else:
266
- args_general = argparse.Namespace(
267
- image_captions_filename = True,
268
- train_text_encoder = True if stptxt > 0 else False,
269
- stop_text_encoder_training = stptxt,
270
- save_n_steps = 0,
271
- pretrained_model_name_or_path = model_to_load,
272
- instance_data_dir="instance_images",
273
- class_data_dir="Mix",
274
- output_dir="output_model",
275
- with_prior_preservation=True,
276
- prior_loss_weight=1.0,
277
- instance_prompt="",
278
- seed=42,
279
- resolution=resolution,
280
- mixed_precision="fp16",
281
- train_batch_size=1,
282
- gradient_accumulation_steps=1,
283
- use_8bit_adam=True,
284
- learning_rate=2e-6,
285
- lr_scheduler="polynomial",
286
- lr_warmup_steps = 0,
287
- max_train_steps=Training_Steps,
288
- num_class_images=200,
289
- gradient_checkpointing=gradient_checkpointing,
290
- cache_latents=cache_latents,
291
- )
292
- print("Starting multi-training...")
293
- lock_file = open("intraining.lock", "w")
294
- lock_file.close()
295
- run_training(args_general)
296
  gc.collect()
297
  torch.cuda.empty_cache()
298
  if(which_model == "v1-5"):
@@ -302,6 +318,7 @@ def train(*inputs):
302
  shutil.copy(f"model_index.json", "output_model/model_index.json")
303
 
304
  if(not remove_attribution_after):
 
305
  print("Archiving model file...")
306
  with tarfile.open("diffusers_model.tar", "w") as tar:
307
  tar.add("output_model", arcname=os.path.basename("output_model"))
@@ -310,6 +327,7 @@ def train(*inputs):
310
  trained_file.close()
311
  print("Training completed!")
312
  return [
 
313
  gr.update(visible=True, value=["diffusers_model.tar"]), #result
314
  gr.update(visible=True), #try_your_model
315
  gr.update(visible=True), #push_to_hub
@@ -320,10 +338,7 @@ def train(*inputs):
320
  else:
321
  where_to_upload = inputs[-8]
322
  push(model_name, where_to_upload, hf_token, which_model, True)
323
- hardware_url = f"https://huggingface.co/spaces/{os.environ['SPACE_ID']}/hardware"
324
- headers = { "authorization" : f"Bearer {hf_token}"}
325
- body = {'flavor': 'cpu-basic'}
326
- requests.post(hardware_url, json = body, headers=headers)
327
 
328
  pipe_is_set = False
329
  def generate(prompt, steps):
@@ -338,7 +353,7 @@ def generate(prompt, steps):
338
 
339
  image = pipe(prompt, num_inference_steps=steps).images[0]
340
  return(image)
341
-
342
  def push(model_name, where_to_upload, hf_token, which_model, comes_from_automated=False):
343
  validate_model_upload(hf_token, model_name)
344
  if(not os.path.exists("model.ckpt")):
@@ -425,7 +440,10 @@ Sample pictures of:
425
  extra_message = "Don't forget to remove the GPU attribution after you play with it."
426
  else:
427
  extra_message = "The GPU has been removed automatically as requested, and you can try the model via the model page"
428
- api.create_discussion(repo_id=os.environ['SPACE_ID'], title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!", description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}",repo_type="space", token=hf_token)
 
 
 
429
  print("Model uploaded successfully!")
430
  return [gr.update(visible=True, value=f"Successfully uploaded your model. Access it [here](https://huggingface.co/{model_id})"), gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])]
431
 
@@ -488,8 +506,8 @@ with gr.Blocks(css=css) as demo:
488
  <div class="gr-prose" style="max-width: 80%">
489
  <h2>Attention - This Space doesn't work in this shared UI</h2>
490
  <p>For it to work, you can either run locally or duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it!&nbsp;&nbsp;<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></p>
491
- <img class="instruction" src="file/duplicate.png">
492
- <img class="arrow" src="file/arrow.png" />
493
  </div>
494
  ''')
495
  elif(is_spaces):
@@ -519,14 +537,15 @@ with gr.Blocks(css=css) as demo:
519
 
520
  with gr.Row() as what_are_you_training:
521
  type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
522
- base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-1-512", "v2-1-768"], value="v1-5", interactive=True)
523
-
 
524
  #Very hacky approach to emulate dynamically created Gradio components
525
  with gr.Row() as upload_your_concept:
526
  with gr.Column():
527
  thing_description = gr.Markdown("You are going to train an `object`, please upload 5-10 images of the object you are planning on training on from different angles/perspectives. You must have the right to do so and you are liable for the images you use, example")
528
  thing_experimental = gr.Checkbox(label="Improve faces (prior preservation) - can take longer training but can improve faces", visible=False, value=False)
529
- thing_image_example = gr.HTML('''<img src="file/cat-toy.png" />''')
530
  things_naming = gr.Markdown("You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `cttoy` here). Images will be automatically cropped to 512x512.")
531
 
532
  with gr.Column():
@@ -588,6 +607,7 @@ with gr.Blocks(css=css) as demo:
588
  training_summary_token = gr.Textbox(label="Hugging Face Write Token", type="password", visible=True)
589
 
590
  train_btn = gr.Button("Start Training")
 
591
  if(is_shared_ui):
592
  training_ongoing = gr.Markdown("## This Space only works in duplicated instances. Please duplicate it and try again!", visible=False)
593
  elif(not is_gpu_associated):
@@ -595,6 +615,7 @@ with gr.Blocks(css=css) as demo:
595
  else:
596
  training_ongoing = gr.Markdown("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ", visible=False)
597
 
 
598
  #Post-training UI
599
  completed_training = gr.Markdown('''# ✅ Training completed.
600
  ### Don't forget to remove the GPU attribution after you are done trying and uploading your model''', visible=False)
@@ -624,9 +645,10 @@ with gr.Blocks(css=css) as demo:
624
  type_of_thing.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
625
 
626
  #Swap the base model
 
627
  base_model_to_use.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
 
628
  base_model_to_use.change(fn=swap_base_model, inputs=base_model_to_use, outputs=[])
629
-
630
  #Update the summary box below the UI according to how many images are uploaded and whether users are using custom settings or not
631
  for file in file_collection:
632
  #file.change(fn=update_steps,inputs=file_collection, outputs=steps)
@@ -641,10 +663,12 @@ with gr.Blocks(css=css) as demo:
641
  if(is_spaces):
642
  training_summary_checkbox.change(fn=checkbox_swap, inputs=training_summary_checkbox, outputs=[training_summary_token_message, training_summary_token, training_summary_model_name, training_summary_where_to_upload],queue=False, show_progress=False)
643
  #Add a message for while it is in training
644
- train_btn.click(lambda:gr.update(visible=True), inputs=None, outputs=training_ongoing)
 
645
 
646
  #The main train function
647
- train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[base_model_to_use]+[thing_experimental]+[training_summary_where_to_upload]+[training_summary_model_name]+[training_summary_checkbox]+[training_summary_token]+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub, convert_button, training_ongoing, completed_training], queue=False)
 
648
 
649
  #Button to generate an image from your trained model after training
650
  generate_button.click(fn=generate, inputs=[prompt, inference_steps], outputs=result_image, queue=False)
 
35
  is_shared_ui = False
36
  is_gpu_associated = torch.cuda.is_available()
37
 
38
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 
 
 
 
 
39
 
 
40
  if(is_gpu_associated):
41
  model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
42
  model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-1", ignore_patterns=["*.ckpt", "*.safetensors"])
 
44
  safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
45
  model_to_load = model_v1
46
 
47
+ def swap_base_model(selected_model):
48
+ if(is_gpu_associated):
49
+ global model_to_load
50
+ if(selected_model == "v1-5"):
51
+ model_to_load = model_v1
52
+ elif(selected_model == "v2-1-768"):
53
+ model_to_load = model_v2
54
+ else:
55
+ model_to_load = model_v2_512
56
+
57
+
58
+
59
+ css = '''
60
+ .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
61
+ .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
62
+ #component-4, #component-3, #component-10{min-height: 0}
63
+ .duplicate-button img{margin: 0}
64
+ '''
65
+ maximum_concepts = 3
66
 
67
  def swap_text(option, base):
68
  resize_width = 768 if base == "v2-1-768" else 512
 
70
  if(option == "object"):
71
  instance_prompt_example = "cttoy"
72
  freeze_for = 30
73
+ return [f"You are going to train `object`(s), upload 5-10 images of each object you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=cat-toy.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, gr.update(visible=False)]
74
  elif(option == "person"):
75
  instance_prompt_example = "julcto"
76
  freeze_for = 70
 
80
  prior_preservation_box_update = gr.update(visible=show_prior_preservation)
81
  else:
82
  prior_preservation_box_update = gr.update(visible=show_prior_preservation, value=False)
83
+ return [f"You are going to train a `person`(s), upload 10-20 images of each person you are planning on training on from different angles/perspectives. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=person.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}.", freeze_for, prior_preservation_box_update]
84
  elif(option == "style"):
85
  instance_prompt_example = "trsldamrl"
86
  freeze_for = 10
87
+ return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. You can use services like <a style='text-decoration: underline' target='_blank' href='https://www.birme.net/?target_width={resize_width}&target_height={resize_width}'>birme</a> for smart cropping. {mandatory_liability}:", '''<img src="file=trsl_style.png" />''', f"You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to {resize_width}x{resize_width}", freeze_for, gr.update(visible=False)]
 
 
 
 
 
 
 
 
 
 
88
 
89
  def count_files(*inputs):
90
  file_counter = 0
91
  concept_counter = 0
92
  for i, input in enumerate(inputs):
93
+ if(i < maximum_concepts):
94
  files = inputs[i]
95
  if(files):
96
  concept_counter+=1
 
133
  file_counter+=len(files)
134
  return(gr.update(value=file_counter*200))
135
 
136
+ def visualise_progress_bar():
137
+ return gr.update(visible=True)
138
+
139
  def pad_image(image):
140
  w, h = image.size
141
  if w == h:
 
166
  if(model_name == ""):
167
  raise gr.Error("Please fill in your model's name")
168
 
169
+ def swap_hardware(hf_token, hardware="cpu-basic"):
170
+ hardware_url = f"https://huggingface.co/spaces/{os.environ['SPACE_ID']}/hardware"
171
+ headers = { "authorization" : f"Bearer {hf_token}"}
172
+ body = {'flavor': hardware}
173
+ requests.post(hardware_url, json = body, headers=headers)
174
+
175
+ def swap_sleep_time(hf_token,sleep_time):
176
+ sleep_time_url = f"https://huggingface.co/api/spaces/{os.environ['SPACE_ID']}/sleeptime"
177
+ headers = { "authorization" : f"Bearer {hf_token}"}
178
+ body = {'seconds':sleep_time}
179
+ requests.post(sleep_time_url,json=body,headers=headers)
180
+
181
+ def get_sleep_time(hf_token):
182
+ sleep_time_url = f"https://huggingface.co/api/spaces/{os.environ['SPACE_ID']}"
183
+ headers = { "authorization" : f"Bearer {hf_token}"}
184
+ response = requests.get(sleep_time_url,headers=headers)
185
+ return response.json()['runtime']['gcTimeout']
186
+
187
+ def write_to_community(title, description,hf_token):
188
+ from huggingface_hub import HfApi
189
+ api = HfApi()
190
+ api.create_discussion(repo_id=os.environ['SPACE_ID'], title=title, description=description,repo_type="space", token=hf_token)
191
+
192
+ def train(progress=gr.Progress(track_tqdm=True), *inputs):
193
+ which_model = inputs[-10]
194
+ if(which_model == ""):
195
+ raise gr.Error("You forgot to select a base model to use")
196
+
197
  if is_shared_ui:
198
  raise gr.Error("This Space only works in duplicated instances")
199
  if not is_gpu_associated:
 
201
  hf_token = inputs[-5]
202
  model_name = inputs[-7]
203
  if(is_spaces):
204
+ sleep_time = get_sleep_time(hf_token)
205
+ if sleep_time:
206
+ swap_sleep_time(hf_token, -1)
207
  remove_attribution_after = inputs[-6]
208
  else:
209
  remove_attribution_after = False
 
224
  if os.path.exists("model.ckpt"): os.remove("model.ckpt")
225
  if os.path.exists("hastrained.success"): os.remove("hastrained.success")
226
  file_counter = 0
 
227
  resolution = 512 if which_model != "v2-1-768" else 768
228
  for i, input in enumerate(inputs):
229
  if(i < maximum_concepts-1):
 
293
  print("Starting single training...")
294
  lock_file = open("intraining.lock", "w")
295
  lock_file.close()
296
+ try:
297
+ run_training(args_general)
298
+ except Exception as e:
299
+ if(is_spaces):
300
+ title="There was an error on during your training"
301
+ description=f'''
302
+ Unfortunately there was an error during training your {model_name} model.
303
+ Please check it out below. Feel free to report this issue to [Dreambooth Training](https://huggingface.co/spaces/multimodalart/dreambooth-training):
304
+ ```
305
+ {str(e)}
306
+ ```
307
+ '''
308
+ swap_hardware(hf_token, "cpu-basic")
309
+ write_to_community(title,description,hf_token)
310
+
311
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  gc.collect()
313
  torch.cuda.empty_cache()
314
  if(which_model == "v1-5"):
 
318
  shutil.copy(f"model_index.json", "output_model/model_index.json")
319
 
320
  if(not remove_attribution_after):
321
+ swap_sleep_time(hf_token, sleep_time)
322
  print("Archiving model file...")
323
  with tarfile.open("diffusers_model.tar", "w") as tar:
324
  tar.add("output_model", arcname=os.path.basename("output_model"))
 
327
  trained_file.close()
328
  print("Training completed!")
329
  return [
330
+ gr.update(visible=False), #progress_bar
331
  gr.update(visible=True, value=["diffusers_model.tar"]), #result
332
  gr.update(visible=True), #try_your_model
333
  gr.update(visible=True), #push_to_hub
 
338
  else:
339
  where_to_upload = inputs[-8]
340
  push(model_name, where_to_upload, hf_token, which_model, True)
341
+ swap_hardware(hf_token, "cpu-basic")
 
 
 
342
 
343
  pipe_is_set = False
344
  def generate(prompt, steps):
 
353
 
354
  image = pipe(prompt, num_inference_steps=steps).images[0]
355
  return(image)
356
+
357
  def push(model_name, where_to_upload, hf_token, which_model, comes_from_automated=False):
358
  validate_model_upload(hf_token, model_name)
359
  if(not os.path.exists("model.ckpt")):
 
440
  extra_message = "Don't forget to remove the GPU attribution after you play with it."
441
  else:
442
  extra_message = "The GPU has been removed automatically as requested, and you can try the model via the model page"
443
+ title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!"
444
+ description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}"
445
+ write_to_community(title, description, hf_token)
446
+ #api.create_discussion(repo_id=os.environ['SPACE_ID'], title=f"Your model {model_name} has finished trained from the Dreambooth Train Spaces!", description=f"Your model has been successfully uploaded to: https://huggingface.co/{model_id}. {extra_message}",repo_type="space", token=hf_token)
447
  print("Model uploaded successfully!")
448
  return [gr.update(visible=True, value=f"Successfully uploaded your model. Access it [here](https://huggingface.co/{model_id})"), gr.update(visible=True, value=["diffusers_model.tar", "model.ckpt"])]
449
 
 
506
  <div class="gr-prose" style="max-width: 80%">
507
  <h2>Attention - This Space doesn't work in this shared UI</h2>
508
  <p>For it to work, you can either run locally or duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it!&nbsp;&nbsp;<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></p>
509
+ <img class="instruction" src="file=duplicate.png">
510
+ <img class="arrow" src="file=arrow.png" />
511
  </div>
512
  ''')
513
  elif(is_spaces):
 
537
 
538
  with gr.Row() as what_are_you_training:
539
  type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
540
+ with gr.Column():
541
+ base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-1-512", "v2-1-768"], value="v1-5", interactive=True)
542
+
543
  #Very hacky approach to emulate dynamically created Gradio components
544
  with gr.Row() as upload_your_concept:
545
  with gr.Column():
546
  thing_description = gr.Markdown("You are going to train an `object`, please upload 5-10 images of the object you are planning on training on from different angles/perspectives. You must have the right to do so and you are liable for the images you use, example")
547
  thing_experimental = gr.Checkbox(label="Improve faces (prior preservation) - can take longer training but can improve faces", visible=False, value=False)
548
+ thing_image_example = gr.HTML('''<img src="file=cat-toy.png" />''')
549
  things_naming = gr.Markdown("You should name your concept with a unique made up word that has low chance of the model already knowing it (e.g.: `cttoy` here). Images will be automatically cropped to 512x512.")
550
 
551
  with gr.Column():
 
607
  training_summary_token = gr.Textbox(label="Hugging Face Write Token", type="password", visible=True)
608
 
609
  train_btn = gr.Button("Start Training")
610
+ progress_bar = gr.Textbox(visible=False)
611
  if(is_shared_ui):
612
  training_ongoing = gr.Markdown("## This Space only works in duplicated instances. Please duplicate it and try again!", visible=False)
613
  elif(not is_gpu_associated):
 
615
  else:
616
  training_ongoing = gr.Markdown("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ", visible=False)
617
 
618
+
619
  #Post-training UI
620
  completed_training = gr.Markdown('''# ✅ Training completed.
621
  ### Don't forget to remove the GPU attribution after you are done trying and uploading your model''', visible=False)
 
645
  type_of_thing.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
646
 
647
  #Swap the base model
648
+
649
  base_model_to_use.change(fn=swap_text, inputs=[type_of_thing, base_model_to_use], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder, thing_experimental], queue=False, show_progress=False)
650
+ #base_model_to_use.change(fn=visualise_progress_bar, inputs=[], outputs=progress_bar)
651
  base_model_to_use.change(fn=swap_base_model, inputs=base_model_to_use, outputs=[])
 
652
  #Update the summary box below the UI according to how many images are uploaded and whether users are using custom settings or not
653
  for file in file_collection:
654
  #file.change(fn=update_steps,inputs=file_collection, outputs=steps)
 
663
  if(is_spaces):
664
  training_summary_checkbox.change(fn=checkbox_swap, inputs=training_summary_checkbox, outputs=[training_summary_token_message, training_summary_token, training_summary_model_name, training_summary_where_to_upload],queue=False, show_progress=False)
665
  #Add a message for while it is in training
666
+
667
+ #train_btn.click(lambda:gr.update(visible=True), inputs=None, outputs=training_ongoing)
668
 
669
  #The main train function
670
+ train_btn.click(lambda:gr.update(visible=True), inputs=[], outputs=progress_bar)
671
+ train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[base_model_to_use]+[thing_experimental]+[training_summary_where_to_upload]+[training_summary_model_name]+[training_summary_checkbox]+[training_summary_token]+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[progress_bar, result, try_your_model, push_to_hub, convert_button, training_ongoing, completed_training], queue=False)
672
 
673
  #Button to generate an image from your trained model after training
674
  generate_button.click(fn=generate, inputs=[prompt, inference_steps], outputs=result_image, queue=False)