hysts committed
Commit 69ed433 • 1 Parent(s): e305340

Use Uploader to upload models in training time
Using two different upload methods was not a good idea, so stop using the upload method provided by train_dreambooth_lora.py and use the Uploader class in this repo instead. Also, to make it easier to port updates to train_dreambooth_lora.py from the diffusers library, reset the changes to that file.
- train_dreambooth_lora.py +39 -44
- trainer.py +7 -0
- utils.py +38 -0
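In rough terms, the flow this commit sets up looks like the sketch below. The `save_model_card` call mirrors the new code in trainer.py; the `LoRAModelUploader` interface lives in app_upload.py and is not part of this diff, so the upload step is only described in comments. All concrete values are illustrative.

```python
# Sketch of the post-training flow after this commit (illustrative values, not repo code).
import pathlib

from app_upload import LoRAModelUploader  # the single Uploader class used by this app
from utils import save_model_card

# train_dreambooth_lora.py is run as a subprocess and kept close to upstream diffusers;
# its own --push_to_hub upload path is no longer used by the app.
output_dir = pathlib.Path('results/example-lora')  # hypothetical training output directory

# Once training finishes, the trainer writes the model card itself:
save_model_card(save_dir=output_dir,
                base_model='runwayml/stable-diffusion-v1-5',  # hypothetical base model
                instance_prompt='a photo of sks dog',         # hypothetical prompts
                test_prompt='a photo of sks dog in a bucket',
                test_image_dir='test_images')

# The Hub upload is then delegated to LoRAModelUploader; its constructor and upload
# method are defined in app_upload.py and are not shown in this diff.
```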
train_dreambooth_lora.py
CHANGED
@@ -1,8 +1,9 @@
 #!/usr/bin/env python
-# This file is adapted from https://github.com/huggingface/diffusers/blob/a66f2baeb782e091dde4e1e6394e46f169e5ba58/examples/dreambooth/train_dreambooth_lora.py
-# The original license is as below.
-#
 # coding=utf-8
+#
+# This file is copied from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
+# The original license is as below:
+#
 # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,6 +26,7 @@ import warnings
 from pathlib import Path
 from typing import Optional
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -48,7 +50,7 @@ from diffusers.models.cross_attention import LoRACrossAttnProcessor
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from huggingface_hub import HfFolder, Repository, create_repo,
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -61,9 +63,9 @@ check_min_version("0.12.0.dev0")
 logger = get_logger(__name__)
 
 
-def save_model_card(repo_name,
-img_str =
-for i, image in enumerate(images
+def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
+img_str = ""
+for i, image in enumerate(images):
 image.save(os.path.join(repo_folder, f"image_{i}.png"))
 img_str += f"![img_{i}](./image_{i}.png)\n"
 
@@ -71,7 +73,6 @@ def save_model_card(repo_name, base_model, instance_prompt, test_prompt="", imag
 ---
 license: creativeml-openrail-m
 base_model: {base_model}
-instance_prompt: {instance_prompt}
 tags:
 - stable-diffusion
 - stable-diffusion-diffusers
@@ -79,11 +80,11 @@ tags:
 - diffusers
 inference: true
 ---
-"""
+"""
 model_card = f"""
 # LoRA DreamBooth - {repo_name}
 
-These are LoRA adaption weights for
+These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
 {img_str}
 """
 with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -364,9 +365,6 @@ def parse_args(input_args=None):
 parser.add_argument(
 "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
 )
-parser.add_argument("--private_repo", action="store_true")
-parser.add_argument("--delete_existing_repo", action="store_true")
-parser.add_argument("--upload_to_lora_library", action="store_true")
 
 if input_args is not None:
 args = parser.parse_args(input_args)
@@ -610,17 +608,11 @@ def main(args):
 if accelerator.is_main_process:
 if args.push_to_hub:
 if args.hub_model_id is None:
-
-repo_name = get_full_repo_name(Path(args.output_dir).name, organization=organization, token=args.hub_token)
+repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
 else:
 repo_name = args.hub_model_id
 
-
-try:
-delete_repo(repo_name, token=args.hub_token)
-except Exception:
-pass
-create_repo(repo_name, token=args.hub_token, private=args.private_repo)
+create_repo(repo_name, exist_ok=True, token=args.hub_token)
 repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
 
 with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
@@ -826,14 +818,21 @@ def main(args):
 dirs = os.listdir(args.output_dir)
 dirs = [d for d in dirs if d.startswith("checkpoint")]
 dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
-path = dirs[-1]
-accelerator.print(f"Resuming from checkpoint {path}")
-accelerator.load_state(os.path.join(args.output_dir, path))
-global_step = int(path.split("-")[1])
+path = dirs[-1] if len(dirs) > 0 else None
 
-
-
-
+if path is None:
+accelerator.print(
+f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+)
+args.resume_from_checkpoint = None
+else:
+accelerator.print(f"Resuming from checkpoint {path}")
+accelerator.load_state(os.path.join(args.output_dir, path))
+global_step = int(path.split("-")[1])
+
+resume_global_step = global_step * args.gradient_accumulation_steps
+first_epoch = global_step // num_update_steps_per_epoch
+resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
 # Only show the progress bar once on each machine.
 progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
@@ -943,6 +942,9 @@ def main(args):
 images = pipeline(prompt, num_inference_steps=25, generator=generator).images
 
 for tracker in accelerator.trackers:
+if tracker.name == "tensorboard":
+np_images = np.stack([np.asarray(img) for img in images])
+tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
 if tracker.name == "wandb":
 tracker.log(
 {
@@ -974,11 +976,15 @@ def main(args):
 pipeline.unet.load_attn_procs(args.output_dir)
 
 # run inference
-
-
-
+if args.validation_prompt and args.num_validation_images > 0:
+generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+prompt = args.num_validation_images * [args.validation_prompt]
+images = pipeline(prompt, num_inference_steps=25, generator=generator).images
 
 for tracker in accelerator.trackers:
+if tracker.name == "tensorboard":
+np_images = np.stack([np.asarray(img) for img in images])
+tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
 if tracker.name == "wandb":
 tracker.log(
 {
@@ -992,23 +998,12 @@ def main(args):
 if args.push_to_hub:
 save_model_card(
 repo_name,
-base_model=args.pretrained_model_name_or_path,
-instance_prompt=args.instance_prompt,
-test_prompt=args.validation_prompt,
 images=images,
-repo_folder=args.output_dir,
-)
-repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
-else:
-repo_name = Path(args.output_dir).name
-save_model_card(
-repo_name,
 base_model=args.pretrained_model_name_or_path,
-
-test_prompt=args.validation_prompt,
-images=images,
+prompt=args.instance_prompt,
 repo_folder=args.output_dir,
 )
+repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
 
 accelerator.end_training()
 
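For reference, the checkpoint-resume bookkeeping restored from upstream in the `@@ -826,14 +818,21 @@` hunk above maps a `checkpoint-<N>` directory back to an epoch and an in-epoch batch offset. A small worked example with made-up values for `num_update_steps_per_epoch` and `--gradient_accumulation_steps`:

```python
# Illustrative numbers only; they are not taken from this repo's configuration.
path = 'checkpoint-500'                # most recent checkpoint directory name
gradient_accumulation_steps = 2
num_update_steps_per_epoch = 150       # optimizer updates per epoch for some dataset

global_step = int(path.split('-')[1])                            # 500 optimizer updates done
resume_global_step = global_step * gradient_accumulation_steps   # 1000 batches consumed
first_epoch = global_step // num_update_steps_per_epoch          # 3 -> skip 3 full epochs
resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
print(first_epoch, resume_step)  # 3 100 -> skip the first 100 batches of epoch 3
```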
|
trainer.py
CHANGED
@@ -14,6 +14,7 @@ import torch
 from huggingface_hub import HfApi
 
 from app_upload import LoRAModelUploader
+from utils import save_model_card
 
 
 def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
@@ -125,6 +126,12 @@ class Trainer:
 command_s = ' '.join(command.split())
 f.write(command_s)
 subprocess.run(shlex.split(command))
+save_model_card(save_dir=output_dir,
+base_model=base_model,
+instance_prompt=instance_prompt,
+test_prompt=validation_prompt,
+test_image_dir='test_images')
+
 message = 'Training completed!'
 print(message)
 
utils.py
CHANGED
@@ -18,3 +18,41 @@ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
 exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
 ]
 return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
+
+
+def save_model_card(
+save_dir: pathlib.Path,
+base_model: str,
+instance_prompt: str,
+test_prompt: str = '',
+test_image_dir: str = '',
+) -> None:
+image_str = ''
+if test_prompt and test_image_dir:
+image_paths = sorted((save_dir / test_image_dir).glob('*'))
+if image_paths:
+image_str = f'Test prompt: {test_prompt}\n'
+for image_path in image_paths:
+rel_path = image_path.relative_to(save_dir)
+image_str += f'![{image_path.stem}]({rel_path})\n'
+
+model_card = f'''---
+license: creativeml-openrail-m
+base_model: {base_model}
+instance_prompt: {instance_prompt}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+# LoRA DreamBooth - {save_dir.name}
+
+These are LoRA adaption weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
+
+{image_str}
+'''
+
+with open(save_dir / 'README.md', 'w') as f:
+f.write(model_card)
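A minimal way to exercise the new helper outside of the trainer, assuming it is run from the repo root so that `utils` is importable and that Pillow is installed (as it already is for trainer.py); all values and file names are made up:

```python
import pathlib
import tempfile

from PIL import Image

from utils import save_model_card

with tempfile.TemporaryDirectory() as tmp:
    save_dir = pathlib.Path(tmp)
    (save_dir / 'test_images').mkdir()
    for i in range(2):  # stand-ins for the images written during training
        Image.new('RGB', (64, 64)).save(save_dir / 'test_images' / f'{i:03d}.png')

    save_model_card(save_dir=save_dir,
                    base_model='runwayml/stable-diffusion-v1-5',
                    instance_prompt='a photo of sks dog',
                    test_prompt='a photo of sks dog in a bucket',
                    test_image_dir='test_images')

    card = (save_dir / 'README.md').read_text()
    assert card.startswith('---')                        # YAML front matter comes first
    assert 'instance_prompt: a photo of sks dog' in card
    assert '![000](test_images/000.png)' in card         # images are linked relative to save_dir
```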