Spaces:
Runtime error
Runtime error
Commit
•
a47c17f
1
Parent(s):
74e4319
Add v2-768
Browse files- app.py +4 -5
- train_dreambooth.py +3 -3
app.py
CHANGED
@@ -58,8 +58,8 @@ def swap_base_model(selected_model):
|
|
58 |
global model_to_load
|
59 |
if(selected_model == "v1-5"):
|
60 |
model_to_load = model_v1
|
61 |
-
|
62 |
-
|
63 |
else:
|
64 |
model_to_load = model_v2_512
|
65 |
|
@@ -171,8 +171,7 @@ def train(*inputs):
|
|
171 |
Training_Steps=1400
|
172 |
|
173 |
stptxt = int((Training_Steps*Train_text_encoder_for)/100)
|
174 |
-
|
175 |
-
gradient_checkpointing=False
|
176 |
resolution = 512 if which_model != "v2-768" else 768
|
177 |
cache_latents = True if which_model != "v1-5" else False
|
178 |
if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
|
@@ -445,7 +444,7 @@ with gr.Blocks(css=css) as demo:
|
|
445 |
|
446 |
with gr.Row() as what_are_you_training:
|
447 |
type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
|
448 |
-
base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512"], value="v1-5", interactive=True)
|
449 |
|
450 |
#Very hacky approach to emulate dynamically created Gradio components
|
451 |
with gr.Row() as upload_your_concept:
|
|
|
58 |
global model_to_load
|
59 |
if(selected_model == "v1-5"):
|
60 |
model_to_load = model_v1
|
61 |
+
elif(selected_model == "v2-768"):
|
62 |
+
model_to_load = model_v2
|
63 |
else:
|
64 |
model_to_load = model_v2_512
|
65 |
|
|
|
171 |
Training_Steps=1400
|
172 |
|
173 |
stptxt = int((Training_Steps*Train_text_encoder_for)/100)
|
174 |
+
gradient_checkpointing = False if which_model == "v1-5" else True
|
|
|
175 |
resolution = 512 if which_model != "v2-768" else 768
|
176 |
cache_latents = True if which_model != "v1-5" else False
|
177 |
if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
|
|
|
444 |
|
445 |
with gr.Row() as what_are_you_training:
|
446 |
type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
|
447 |
+
base_model_to_use = gr.Dropdown(label="Which base model would you like to use?", choices=["v1-5", "v2-512", "v2-768"], value="v1-5", interactive=True)
|
448 |
|
449 |
#Very hacky approach to emulate dynamically created Gradio components
|
450 |
with gr.Row() as upload_your_concept:
|
train_dreambooth.py
CHANGED
@@ -710,10 +710,10 @@ def run_training(args_imported):
|
|
710 |
# Convert images to latent space
|
711 |
with torch.no_grad():
|
712 |
if args.cache_latents:
|
713 |
-
|
714 |
else:
|
715 |
-
|
716 |
-
latents =
|
717 |
|
718 |
# Sample noise that we'll add to the latents
|
719 |
noise = torch.randn_like(latents)
|
|
|
710 |
# Convert images to latent space
|
711 |
with torch.no_grad():
|
712 |
if args.cache_latents:
|
713 |
+
latents_dist = batch[0][0]
|
714 |
else:
|
715 |
+
latents_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
|
716 |
+
latents = latents_dist.sample() * 0.18215
|
717 |
|
718 |
# Sample noise that we'll add to the latents
|
719 |
noise = torch.randn_like(latents)
|