Spaces:
Build error
Build error
Nupur Kumari
commited on
Commit
·
f4d0eb6
1
Parent(s):
c5ee943
update
Browse files- trainer.py +4 -1
trainer.py
CHANGED
@@ -70,6 +70,7 @@ class Trainer:
|
|
70 |
batch_size: int,
|
71 |
use_8bit_adam: bool,
|
72 |
gradient_checkpointing: bool,
|
|
|
73 |
) -> tuple[dict, list[pathlib.Path]]:
|
74 |
if not torch.cuda.is_available():
|
75 |
raise gr.Error('CUDA is not available.')
|
@@ -94,7 +95,7 @@ class Trainer:
|
|
94 |
--output_dir={self.output_dir} \
|
95 |
--instance_prompt="{concept_prompt}" \
|
96 |
--class_data_dir={self.class_data_dir} \
|
97 |
-
--with_prior_preservation --
|
98 |
--class_prompt="{class_prompt}" \
|
99 |
--resolution={resolution} \
|
100 |
--train_batch_size={batch_size} \
|
@@ -108,6 +109,8 @@ class Trainer:
|
|
108 |
'''
|
109 |
if modifier_token:
|
110 |
command += ' --modifier_token "<new1>"'
|
|
|
|
|
111 |
if use_8bit_adam:
|
112 |
command += ' --use_8bit_adam'
|
113 |
if train_text_encoder:
|
|
|
70 |
batch_size: int,
|
71 |
use_8bit_adam: bool,
|
72 |
gradient_checkpointing: bool,
|
73 |
+
gen_images: bool,
|
74 |
) -> tuple[dict, list[pathlib.Path]]:
|
75 |
if not torch.cuda.is_available():
|
76 |
raise gr.Error('CUDA is not available.')
|
|
|
95 |
--output_dir={self.output_dir} \
|
96 |
--instance_prompt="{concept_prompt}" \
|
97 |
--class_data_dir={self.class_data_dir} \
|
98 |
+
--with_prior_preservation --prior_loss_weight=1.0 \
|
99 |
--class_prompt="{class_prompt}" \
|
100 |
--resolution={resolution} \
|
101 |
--train_batch_size={batch_size} \
|
|
|
109 |
'''
|
110 |
if modifier_token:
|
111 |
command += ' --modifier_token "<new1>"'
|
112 |
+
if not gen_images:
|
113 |
+
command += ' --real_prior'
|
114 |
if use_8bit_adam:
|
115 |
command += ' --use_8bit_adam'
|
116 |
if train_text_encoder:
|