hysts commited on
Commit
ccc363e
1 Parent(s): 57920d0

Add changes from https://github.com/huggingface/diffusers/pull/2106

Browse files
Files changed (1) hide show
  1. train_dreambooth_lora.py +20 -14
train_dreambooth_lora.py CHANGED
@@ -215,7 +215,13 @@ def parse_args(input_args=None):
215
  ),
216
  )
217
  parser.add_argument(
218
- "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
 
 
 
 
 
 
219
  )
220
  parser.add_argument(
221
  "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
@@ -988,19 +994,19 @@ def main(args):
988
  out_path = test_image_dir / f'image_{i}.png'
989
  image.save(out_path)
990
 
991
- for tracker in accelerator.trackers:
992
- if tracker.name == "tensorboard":
993
- np_images = np.stack([np.asarray(img) for img in images])
994
- tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
995
- if tracker.name == "wandb":
996
- tracker.log(
997
- {
998
- "test": [
999
- wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1000
- for i, image in enumerate(images)
1001
- ]
1002
- }
1003
- )
1004
 
1005
  if args.push_to_hub:
1006
  save_model_card(
 
215
  ),
216
  )
217
  parser.add_argument(
218
+ "--center_crop",
219
+ default=False,
220
+ action="store_true",
221
+ help=(
222
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
223
+ " cropped. The images will be resized to the resolution first before cropping."
224
+ ),
225
  )
226
  parser.add_argument(
227
  "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
 
994
  out_path = test_image_dir / f'image_{i}.png'
995
  image.save(out_path)
996
 
997
+ for tracker in accelerator.trackers:
998
+ if tracker.name == "tensorboard":
999
+ np_images = np.stack([np.asarray(img) for img in images])
1000
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1001
+ if tracker.name == "wandb":
1002
+ tracker.log(
1003
+ {
1004
+ "test": [
1005
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1006
+ for i, image in enumerate(images)
1007
+ ]
1008
+ }
1009
+ )
1010
 
1011
  if args.push_to_hub:
1012
  save_model_card(