Nupur Kumari commited on
Commit
f4d0eb6
·
1 Parent(s): c5ee943
Files changed (1) hide show
  1. trainer.py +4 -1
trainer.py CHANGED
@@ -70,6 +70,7 @@ class Trainer:
70
  batch_size: int,
71
  use_8bit_adam: bool,
72
  gradient_checkpointing: bool,
 
73
  ) -> tuple[dict, list[pathlib.Path]]:
74
  if not torch.cuda.is_available():
75
  raise gr.Error('CUDA is not available.')
@@ -94,7 +95,7 @@ class Trainer:
94
  --output_dir={self.output_dir} \
95
  --instance_prompt="{concept_prompt}" \
96
  --class_data_dir={self.class_data_dir} \
97
- --with_prior_preservation --real_prior --prior_loss_weight=1.0 \
98
  --class_prompt="{class_prompt}" \
99
  --resolution={resolution} \
100
  --train_batch_size={batch_size} \
@@ -108,6 +109,8 @@ class Trainer:
108
  '''
109
  if modifier_token:
110
  command += ' --modifier_token "<new1>"'
 
 
111
  if use_8bit_adam:
112
  command += ' --use_8bit_adam'
113
  if train_text_encoder:
 
70
  batch_size: int,
71
  use_8bit_adam: bool,
72
  gradient_checkpointing: bool,
73
+ gen_images: bool,
74
  ) -> tuple[dict, list[pathlib.Path]]:
75
  if not torch.cuda.is_available():
76
  raise gr.Error('CUDA is not available.')
 
95
  --output_dir={self.output_dir} \
96
  --instance_prompt="{concept_prompt}" \
97
  --class_data_dir={self.class_data_dir} \
98
+ --with_prior_preservation --prior_loss_weight=1.0 \
99
  --class_prompt="{class_prompt}" \
100
  --resolution={resolution} \
101
  --train_batch_size={batch_size} \
 
109
  '''
110
  if modifier_token:
111
  command += ' --modifier_token "<new1>"'
112
+ if not gen_images:
113
+ command += ' --real_prior'
114
  if use_8bit_adam:
115
  command += ' --use_8bit_adam'
116
  if train_text_encoder: