Joeywsuarez mkshing commited on
Commit
3916ae1
0 Parent(s):

Duplicate from svdiff-library/SVDiff-Training-UI

Browse files

Co-authored-by: mkshing <[email protected]>

Files changed (18) hide show
  1. .gitattributes +34 -0
  2. .gitignore +165 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. LICENSE +21 -0
  6. README.md +15 -0
  7. app.py +78 -0
  8. app_inference.py +170 -0
  9. app_training.py +147 -0
  10. app_upload.py +100 -0
  11. constants.py +6 -0
  12. inference.py +111 -0
  13. requirements.txt +5 -0
  14. style.css +3 -0
  15. train_svdiff.py +1057 -0
  16. trainer.py +177 -0
  17. uploader.py +42 -0
  18. utils.py +58 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training_data/
2
+ experiments/
3
+ wandb/
4
+
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: train_dreambooth_lora.py
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SVDiff-pytorch Training UI
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.16.2
8
+ python_version: 3.10.9
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ duplicated_from: svdiff-library/SVDiff-Training-UI
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+ from app_inference import create_inference_demo
11
+ from app_training import create_training_demo
12
+ from app_upload import create_upload_demo
13
+ from inference import InferencePipeline
14
+ from trainer import Trainer
15
+
16
+ TITLE = """# SVDiff-pytorch Training UI
17
+ This demo is based on https://github.com/mkshing/svdiff-pytorch, which is an implementation of "SVDiff: Compact Parameter Space for Diffusion Fine-Tuning" by [mkshing](https://twitter.com/mk1stats)
18
+ """
19
+
20
+ ORIGINAL_SPACE_ID = 'mshing/SVDiff-pytorch-UI'
21
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
22
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
23
+
24
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
25
+ '''
26
+
27
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
28
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
29
+ else:
30
+ SETTINGS = 'Settings'
31
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
32
+ <center>
33
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
34
+ "T4 small" is sufficient to run this demo.
35
+ </center>
36
+ '''
37
+
38
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
39
+ <center>
40
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
41
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
42
+ </center>
43
+ '''
44
+
45
+ HF_TOKEN = os.getenv('HF_TOKEN')
46
+
47
+
48
+ def show_warning(warning_text: str) -> gr.Blocks:
49
+ with gr.Blocks() as demo:
50
+ with gr.Box():
51
+ gr.Markdown(warning_text)
52
+ return demo
53
+
54
+
55
+ pipe = InferencePipeline(HF_TOKEN)
56
+ trainer = Trainer(HF_TOKEN)
57
+
58
+ with gr.Blocks(css='style.css') as demo:
59
+ if os.getenv('IS_SHARED_UI'):
60
+ show_warning(SHARED_UI_WARNING)
61
+ if not torch.cuda.is_available():
62
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
63
+ if not HF_TOKEN:
64
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
65
+
66
+ gr.Markdown(TITLE)
67
+ with gr.Tabs():
68
+ with gr.TabItem('Train'):
69
+ create_training_demo(trainer, pipe)
70
+ with gr.TabItem('Test'):
71
+ create_inference_demo(pipe, HF_TOKEN)
72
+ with gr.TabItem('Upload'):
73
+ gr.Markdown('''
74
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
75
+ ''')
76
+ create_upload_demo(HF_TOKEN)
77
+
78
+ demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from inference import InferencePipeline
11
+ from utils import find_exp_dirs
12
+
13
+ SAMPLE_MODEL_IDS = [
14
+ 'svdiff-library/svdiff_dog_example',
15
+ 'svdiff-library/svdiff_chair_example',
16
+ ]
17
+
18
+
19
+ class ModelSource(enum.Enum):
20
+ SAMPLE = 'Sample'
21
+ HUB_LIB = 'Hub (svdiff-library)'
22
+ LOCAL = 'Local'
23
+
24
+
25
+ class InferenceUtil:
26
+ def __init__(self, hf_token: str | None):
27
+ self.hf_token = hf_token
28
+
29
+ @staticmethod
30
+ def load_sample_model_list():
31
+ return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
32
+
33
+ def load_hub_model_list(self) -> dict:
34
+ api = HfApi(token=self.hf_token)
35
+ choices = [
36
+ info.modelId for info in api.list_models(author='svdiff-library')
37
+ ]
38
+ return gr.update(choices=choices,
39
+ value=choices[0] if choices else None)
40
+
41
+ @staticmethod
42
+ def load_local_model_list() -> dict:
43
+ choices = find_exp_dirs()
44
+ return gr.update(choices=choices,
45
+ value=choices[0] if choices else None)
46
+
47
+ def reload_model_list(self, model_source: str) -> dict:
48
+ if model_source == ModelSource.SAMPLE.value:
49
+ return self.load_sample_model_list()
50
+ elif model_source == ModelSource.HUB_LIB.value:
51
+ return self.load_hub_model_list()
52
+ elif model_source == ModelSource.LOCAL.value:
53
+ return self.load_local_model_list()
54
+ else:
55
+ raise ValueError
56
+
57
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
58
+ try:
59
+ card = InferencePipeline.get_model_card(model_id,
60
+ self.hf_token)
61
+ except Exception:
62
+ return '', ''
63
+ base_model = getattr(card.data, 'base_model', '')
64
+ instance_prompt = getattr(card.data, 'instance_prompt', '')
65
+ return base_model, instance_prompt
66
+
67
+ def reload_model_list_and_update_model_info(
68
+ self, model_source: str) -> tuple[dict, str, str]:
69
+ model_list_update = self.reload_model_list(model_source)
70
+ model_list = model_list_update['choices']
71
+ model_info = self.load_model_info(model_list[0] if model_list else '')
72
+ return model_list_update, *model_info
73
+
74
+
75
+ def create_inference_demo(pipe: InferencePipeline,
76
+ hf_token: str | None = None) -> gr.Blocks:
77
+ app = InferenceUtil(hf_token)
78
+
79
+ with gr.Blocks() as demo:
80
+ with gr.Row():
81
+ with gr.Column():
82
+ with gr.Box():
83
+ model_source = gr.Radio(
84
+ label='Model Source',
85
+ choices=[_.value for _ in ModelSource],
86
+ value=ModelSource.SAMPLE.value)
87
+ reload_button = gr.Button('Reload Model List')
88
+ model_id = gr.Dropdown(label='Model ID',
89
+ choices=SAMPLE_MODEL_IDS,
90
+ value=SAMPLE_MODEL_IDS[0])
91
+ with gr.Accordion(
92
+ label=
93
+ 'Model info (Base model and instance prompt used for training)',
94
+ open=False):
95
+ with gr.Row():
96
+ base_model_used_for_training = gr.Text(
97
+ label='Base model', interactive=False)
98
+ instance_prompt_used_for_training = gr.Text(
99
+ label='Instance prompt', interactive=False)
100
+ prompt = gr.Textbox(
101
+ label='Prompt',
102
+ max_lines=1,
103
+ placeholder='Example: "A picture of a sks dog in a bucket"'
104
+ )
105
+ seed = gr.Slider(label='Seed',
106
+ minimum=0,
107
+ maximum=100000,
108
+ step=1,
109
+ value=0)
110
+ with gr.Accordion('Other Parameters', open=False):
111
+ num_steps = gr.Slider(label='Number of Steps',
112
+ minimum=0,
113
+ maximum=100,
114
+ step=1,
115
+ value=25)
116
+ guidance_scale = gr.Slider(label='CFG Scale',
117
+ minimum=0,
118
+ maximum=50,
119
+ step=0.1,
120
+ value=7.5)
121
+
122
+ run_button = gr.Button('Generate')
123
+
124
+ gr.Markdown('''
125
+ - After training, you can press "Reload Model List" button to load your trained model names.
126
+ ''')
127
+ with gr.Column():
128
+ result = gr.Image(label='Result')
129
+
130
+ model_source.change(
131
+ fn=app.reload_model_list_and_update_model_info,
132
+ inputs=model_source,
133
+ outputs=[
134
+ model_id,
135
+ base_model_used_for_training,
136
+ instance_prompt_used_for_training,
137
+ ])
138
+ reload_button.click(
139
+ fn=app.reload_model_list_and_update_model_info,
140
+ inputs=model_source,
141
+ outputs=[
142
+ model_id,
143
+ base_model_used_for_training,
144
+ instance_prompt_used_for_training,
145
+ ])
146
+ model_id.change(fn=app.load_model_info,
147
+ inputs=model_id,
148
+ outputs=[
149
+ base_model_used_for_training,
150
+ instance_prompt_used_for_training,
151
+ ])
152
+ inputs = [
153
+ model_id,
154
+ prompt,
155
+ seed,
156
+ num_steps,
157
+ guidance_scale,
158
+ ]
159
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
160
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
161
+ return demo
162
+
163
+
164
+ if __name__ == '__main__':
165
+ import os
166
+
167
+ hf_token = os.getenv('HF_TOKEN')
168
+ pipe = InferencePipeline(hf_token)
169
+ demo = create_inference_demo(pipe, hf_token)
170
+ demo.queue(max_size=10).launch(share=True, debug=True)
app_training.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ with gr.Box():
20
+ gr.Markdown('Training Data')
21
+ instance_images = gr.Files(label='Instance images')
22
+ instance_prompt = gr.Textbox(label='Instance prompt',
23
+ max_lines=1)
24
+ gr.Markdown('''
25
+ - Upload images of the style you are planning on training on.
26
+ - For an instance prompt, use a unique, made up word to avoid collisions.
27
+ ''')
28
+ with gr.Box():
29
+ gr.Markdown('Output Model')
30
+ output_model_name = gr.Text(label='Name of your model',
31
+ max_lines=1)
32
+ delete_existing_model = gr.Checkbox(
33
+ label='Delete existing model of the same name',
34
+ value=False)
35
+ validation_prompt = gr.Text(label='Validation Prompt')
36
+ with gr.Box():
37
+ gr.Markdown('Upload Settings')
38
+ with gr.Row():
39
+ upload_to_hub = gr.Checkbox(
40
+ label='Upload model to Hub', value=True)
41
+ use_private_repo = gr.Checkbox(label='Private',
42
+ value=True)
43
+ delete_existing_repo = gr.Checkbox(
44
+ label='Delete existing repo of the same name',
45
+ value=False)
46
+ upload_to = gr.Radio(
47
+ label='Upload to',
48
+ choices=[_.value for _ in UploadTarget],
49
+ value=UploadTarget.SVDIFF_LIBRARY.value)
50
+ gr.Markdown('''
51
+ - By default, trained models will be uploaded to [SVDiff-pytorch Library](https://huggingface.co/svdiff-library).
52
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
53
+ ''')
54
+
55
+ with gr.Box():
56
+ gr.Markdown('Training Parameters')
57
+ with gr.Row():
58
+ base_model = gr.Text(
59
+ label='Base Model',
60
+ value='runwayml/stable-diffusion-v1-5',
61
+ max_lines=1)
62
+ resolution = gr.Dropdown(choices=['512', '768'],
63
+ value='512',
64
+ label='Resolution')
65
+ num_training_steps = gr.Number(
66
+ label='Number of Training Steps', value=1000, precision=0)
67
+ learning_rate = gr.Number(label='Learning Rate', value=0.001)
68
+ gradient_accumulation = gr.Number(
69
+ label='Number of Gradient Accumulation',
70
+ value=1,
71
+ precision=0)
72
+ seed = gr.Slider(label='Seed',
73
+ minimum=0,
74
+ maximum=100000,
75
+ step=1,
76
+ value=0)
77
+ fp16 = gr.Checkbox(label='FP16', value=False)
78
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
79
+ gradient_checkpointing = gr.Checkbox(label='Use gradient checkpointing', value=True)
80
+ # enable_xformers_memory_efficient_attention = gr.Checkbox(label='Use xformers', value=True)
81
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
82
+ value=200,
83
+ precision=0)
84
+ use_wandb = gr.Checkbox(label='Use W&B',
85
+ value=False,
86
+ interactive=bool(
87
+ os.getenv('WANDB_API_KEY')))
88
+ validation_epochs = gr.Number(label='Validation Epochs',
89
+ value=200,
90
+ precision=0)
91
+ gr.Markdown('''
92
+ - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
93
+ - It takes a few minutes to download the base model first.
94
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
95
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
96
+ - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
97
+ - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
98
+ ''')
99
+
100
+ remove_gpu_after_training = gr.Checkbox(
101
+ label='Remove GPU after training',
102
+ value=False,
103
+ interactive=bool(os.getenv('SPACE_ID')),
104
+ visible=False)
105
+ run_button = gr.Button('Start Training')
106
+
107
+ with gr.Box():
108
+ gr.Markdown('Output message')
109
+ output_message = gr.Markdown()
110
+
111
+ if pipe is not None:
112
+ run_button.click(fn=pipe.clear)
113
+ run_button.click(fn=trainer.run,
114
+ inputs=[
115
+ instance_images,
116
+ instance_prompt,
117
+ output_model_name,
118
+ delete_existing_model,
119
+ validation_prompt,
120
+ base_model,
121
+ resolution,
122
+ num_training_steps,
123
+ learning_rate,
124
+ gradient_accumulation,
125
+ seed,
126
+ fp16,
127
+ use_8bit_adam,
128
+ gradient_checkpointing,
129
+ # enable_xformers_memory_efficient_attention,
130
+ checkpointing_steps,
131
+ use_wandb,
132
+ validation_epochs,
133
+ upload_to_hub,
134
+ use_private_repo,
135
+ delete_existing_repo,
136
+ upload_to,
137
+ remove_gpu_after_training,
138
+ ],
139
+ outputs=output_message)
140
+ return demo
141
+
142
+
143
+ if __name__ == '__main__':
144
+ hf_token = os.getenv('HF_TOKEN')
145
+ trainer = Trainer(hf_token)
146
+ demo = create_training_demo(trainer)
147
+ demo.queue(max_size=1).launch(share=True, debug=True)
app_upload.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
+ if not folder_path:
25
+ raise ValueError
26
+ if not repo_name:
27
+ repo_name = pathlib.Path(folder_path).name
28
+ repo_name = slugify.slugify(repo_name)
29
+
30
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
+ organization = ''
32
+ elif upload_to == UploadTarget.SVDIFF_LIBRARY.value:
33
+ organization = 'svdiff-library'
34
+ else:
35
+ raise ValueError
36
+
37
+ return self.upload(folder_path,
38
+ repo_name,
39
+ organization=organization,
40
+ private=private,
41
+ delete_existing_repo=delete_existing_repo)
42
+
43
+
44
+ def load_local_model_list() -> dict:
45
+ choices = find_exp_dirs(ignore_repo=True)
46
+ return gr.update(choices=choices, value=choices[0] if choices else None)
47
+
48
+
49
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = ModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs(ignore_repo=True)
52
+
53
+ with gr.Blocks() as demo:
54
+ with gr.Box():
55
+ gr.Markdown('Local Models')
56
+ reload_button = gr.Button('Reload Model List')
57
+ model_dir = gr.Dropdown(
58
+ label='Model names',
59
+ choices=model_dirs,
60
+ value=model_dirs[0] if model_dirs else None)
61
+ with gr.Box():
62
+ gr.Markdown('Upload Settings')
63
+ with gr.Row():
64
+ use_private_repo = gr.Checkbox(label='Private', value=True)
65
+ delete_existing_repo = gr.Checkbox(
66
+ label='Delete existing repo of the same name', value=False)
67
+ upload_to = gr.Radio(label='Upload to',
68
+ choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.SVDIFF_LIBRARY.value)
70
+ model_name = gr.Textbox(label='Model Name')
71
+ upload_button = gr.Button('Upload')
72
+ gr.Markdown('''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [SVDiff-pytorch Concepts Library](https://huggingface.co/svdiff-library) (i.e. https://huggingface.co/svdiff-library/{model_name}).
74
+ ''')
75
+ with gr.Box():
76
+ gr.Markdown('Output message')
77
+ output_message = gr.Markdown()
78
+
79
+ reload_button.click(fn=load_local_model_list,
80
+ inputs=None,
81
+ outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_model,
83
+ inputs=[
84
+ model_dir,
85
+ model_name,
86
+ upload_to,
87
+ use_private_repo,
88
+ delete_existing_repo,
89
+ ],
90
+ outputs=output_message)
91
+
92
+ return demo
93
+
94
+
95
+ if __name__ == '__main__':
96
+ import os
97
+
98
+ hf_token = os.getenv('HF_TOKEN')
99
+ demo = create_upload_demo(hf_token)
100
+ demo.queue(max_size=1).launch(share=True, debug=True)
constants.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ SVDIFF_LIBRARY = 'SVDiff Library'
inference.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+
6
+ import gradio as gr
7
+ import PIL.Image
8
+ import torch
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from huggingface_hub import ModelCard
11
+ from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid
12
+
13
+
14
+
15
+ class InferencePipeline:
16
+ def __init__(self, hf_token: str | None = None):
17
+ self.hf_token = hf_token
18
+ self.pipe = None
19
+ self.device = torch.device(
20
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
21
+ self.model_id = None
22
+ self.base_model_id = None
23
+
24
+ def clear(self) -> None:
25
+ self.model_id = None
26
+ self.base_model_id = None
27
+ del self.pipe
28
+ self.pipe = None
29
+ torch.cuda.empty_cache()
30
+ gc.collect()
31
+
32
+ @staticmethod
33
+ def check_if_model_is_local(model_id: str) -> bool:
34
+ return pathlib.Path(model_id).exists()
35
+
36
+ @staticmethod
37
+ def get_model_card(model_id: str,
38
+ hf_token: str | None = None) -> ModelCard:
39
+ if InferencePipeline.check_if_model_is_local(model_id):
40
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
41
+ else:
42
+ card_path = model_id
43
+ return ModelCard.load(card_path, token=hf_token)
44
+
45
+ @staticmethod
46
+ def get_base_model_info(model_id: str,
47
+ hf_token: str | None = None) -> str:
48
+ card = InferencePipeline.get_model_card(model_id, hf_token)
49
+ return card.data.base_model
50
+
51
+ def load_pipe(self, model_id: str) -> None:
52
+ if model_id == self.model_id:
53
+ return
54
+
55
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
56
+ unet = load_unet_for_svdiff(base_model_id, spectral_shifts_ckpt=model_id, subfolder="unet").to(self.device)
57
+ # first perform svd and cache
58
+ for module in unet.modules():
59
+ if hasattr(module, "perform_svd"):
60
+ module.perform_svd()
61
+ if self.device.type != 'cpu':
62
+ unet = unet.to(self.device, dtype=torch.float16)
63
+ text_encoder = load_text_encoder_for_svdiff(base_model_id, spectral_shifts_ckpt=model_id, subfolder="text_encoder")
64
+ if self.device.type != 'cpu':
65
+ text_encoder = text_encoder.to(self.device, dtype=torch.float16)
66
+ else:
67
+ text_encoder = text_encoder.to(self.device)
68
+ if base_model_id != self.base_model_id:
69
+ if self.device.type == 'cpu':
70
+ pipe = DiffusionPipeline.from_pretrained(
71
+ base_model_id,
72
+ unet=unet,
73
+ text_encoder=text_encoder,
74
+ use_auth_token=self.hf_token
75
+ )
76
+ else:
77
+ pipe = DiffusionPipeline.from_pretrained(
78
+ base_model_id,
79
+ unet=unet,
80
+ text_encoder=text_encoder,
81
+ torch_dtype=torch.float16,
82
+ use_auth_token=self.hf_token
83
+ )
84
+ pipe = pipe.to(self.device)
85
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
86
+ self.pipe = pipe
87
+
88
+ self.model_id = model_id # type: ignore
89
+ self.base_model_id = base_model_id # type: ignore
90
+
91
+ def run(
92
+ self,
93
+ model_id: str,
94
+ prompt: str,
95
+ seed: int,
96
+ n_steps: int,
97
+ guidance_scale: float,
98
+ ) -> PIL.Image.Image:
99
+ # if not torch.cuda.is_available():
100
+ # raise gr.Error('CUDA is not available.')
101
+
102
+ self.load_pipe(model_id)
103
+
104
+ generator = torch.Generator(device=self.device).manual_seed(seed)
105
+ out = self.pipe(
106
+ prompt,
107
+ num_inference_steps=n_steps,
108
+ guidance_scale=guidance_scale,
109
+ generator=generator,
110
+ ) # type: ignore
111
+ return out.images[0]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ svdiff-pytorch>=0.2.0
2
+ bitsandbytes==0.35.0
3
+ python-slugify==7.0.0
4
+ tomesd
5
+ gradio==3.16.2
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
train_svdiff.py ADDED
@@ -0,0 +1,1057 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import hashlib
3
+ import logging
4
+ import math
5
+ import os
6
+ import warnings
7
+ from pathlib import Path
8
+ from typing import Optional
9
+ from packaging import version
10
+ import itertools
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint
16
+ import transformers
17
+ from accelerate import Accelerator
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration, set_seed
20
+ from huggingface_hub import create_repo, upload_folder
21
+ from packaging import version
22
+ from PIL import Image
23
+ from torch.utils.data import Dataset
24
+ from torchvision import transforms
25
+ from tqdm.auto import tqdm
26
+ from transformers import CLIPTextModel, AutoTokenizer, PretrainedConfig
27
+
28
+ import diffusers
29
+ from diffusers import __version__
30
+ from diffusers import (
31
+ AutoencoderKL,
32
+ DDPMScheduler,
33
+ DiffusionPipeline,
34
+ StableDiffusionPipeline,
35
+ DPMSolverMultistepScheduler,
36
+ )
37
+ from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING
38
+ from diffusers.loaders import AttnProcsLayers
39
+ from diffusers.optimization import get_scheduler
40
+ from diffusers.utils import check_min_version, is_wandb_available
41
+ from diffusers.utils.import_utils import is_xformers_available
42
+ from safetensors import safe_open
43
+ from safetensors.torch import save_file
44
+ if is_wandb_available():
45
+ import wandb
46
+
47
+
48
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
49
+ # check_min_version("0.15.0.dev0")
50
+ diffusers_version = "0.14.0"
51
+ if version.parse(__version__) != version.parse(diffusers_version):
52
+ error_message = f"This example requires a version of {diffusers_version},"
53
+ error_message += f" but the version found is {__version__}.\n"
54
+ raise ImportError(error_message)
55
+
56
+ logger = get_logger(__name__)
57
+
58
+
59
+ def save_model_card(repo_id: str, base_model=str, prompt=str, repo_folder=None):
60
+ yaml = f"""
61
+ ---
62
+ license: creativeml-openrail-m
63
+ base_model: {base_model}
64
+ instance_prompt: {prompt}
65
+ tags:
66
+ - stable-diffusion
67
+ - stable-diffusion-diffusers
68
+ - text-to-image
69
+ - diffusers
70
+ - svdiff
71
+ inference: true
72
+ ---
73
+ """
74
+ model_card = f"""
75
+ # SVDiff-pytorch - {repo_id}
76
+ These are SVDiff weights for {base_model}. The weights were trained on {prompt}.
77
+ """
78
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
79
+ f.write(yaml + model_card)
80
+
81
+
82
+ def parse_args(input_args=None):
83
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
84
+ parser.add_argument(
85
+ "--pretrained_model_name_or_path",
86
+ type=str,
87
+ default=None,
88
+ required=True,
89
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
90
+ )
91
+ parser.add_argument(
92
+ "--pretrained_vae_name_or_path",
93
+ type=str,
94
+ default=None,
95
+ help="Path to pretrained vae or vae identifier from huggingface.co/models. This will be used in prior generation",
96
+ )
97
+ parser.add_argument(
98
+ "--revision",
99
+ type=str,
100
+ default=None,
101
+ required=False,
102
+ help="Revision of pretrained model identifier from huggingface.co/models.",
103
+ )
104
+ parser.add_argument(
105
+ "--tokenizer_name",
106
+ type=str,
107
+ default=None,
108
+ help="Pretrained tokenizer name or path if not the same as model_name",
109
+ )
110
+ parser.add_argument(
111
+ "--instance_data_dir",
112
+ type=str,
113
+ default=None,
114
+ required=True,
115
+ help="A folder containing the training data of instance images.",
116
+ )
117
+ parser.add_argument(
118
+ "--class_data_dir",
119
+ type=str,
120
+ default=None,
121
+ required=False,
122
+ help="A folder containing the training data of class images.",
123
+ )
124
+ parser.add_argument(
125
+ "--instance_prompt",
126
+ type=str,
127
+ default=None,
128
+ required=True,
129
+ help="The prompt with identifier specifying the instance",
130
+ )
131
+ parser.add_argument(
132
+ "--class_prompt",
133
+ type=str,
134
+ default=None,
135
+ help="The prompt to specify images in the same class as provided instance images.",
136
+ )
137
+ parser.add_argument(
138
+ "--validation_prompt",
139
+ type=str,
140
+ default=None,
141
+ help="A prompt that is used during validation to verify that the model is learning.",
142
+ )
143
+ parser.add_argument(
144
+ "--num_validation_images",
145
+ type=int,
146
+ default=4,
147
+ help="Number of images that should be generated during validation with `validation_prompt`.",
148
+ )
149
+ parser.add_argument(
150
+ "--validation_epochs",
151
+ type=int,
152
+ default=50,
153
+ help=(
154
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
155
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
156
+ ),
157
+ )
158
+ parser.add_argument(
159
+ "--with_prior_preservation",
160
+ default=False,
161
+ action="store_true",
162
+ help="Flag to add prior preservation loss.",
163
+ )
164
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
165
+ parser.add_argument(
166
+ "--num_class_images",
167
+ type=int,
168
+ default=100,
169
+ help=(
170
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
171
+ " class_data_dir, additional images will be sampled with class_prompt."
172
+ ),
173
+ )
174
+ parser.add_argument(
175
+ "--output_dir",
176
+ type=str,
177
+ default="lora-dreambooth-model",
178
+ help="The output directory where the model predictions and checkpoints will be written.",
179
+ )
180
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
181
+ parser.add_argument(
182
+ "--resolution",
183
+ type=int,
184
+ default=512,
185
+ help=(
186
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
187
+ " resolution"
188
+ ),
189
+ )
190
+ parser.add_argument(
191
+ "--center_crop",
192
+ default=False,
193
+ action="store_true",
194
+ help=(
195
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
196
+ " cropped. The images will be resized to the resolution first before cropping."
197
+ ),
198
+ )
199
+ parser.add_argument(
200
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
201
+ )
202
+ parser.add_argument(
203
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
204
+ )
205
+ parser.add_argument("--num_train_epochs", type=int, default=1)
206
+ parser.add_argument(
207
+ "--max_train_steps",
208
+ type=int,
209
+ default=None,
210
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
211
+ )
212
+ parser.add_argument(
213
+ "--checkpointing_steps",
214
+ type=int,
215
+ default=500,
216
+ help=(
217
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
218
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
219
+ " training using `--resume_from_checkpoint`."
220
+ ),
221
+ )
222
+ parser.add_argument(
223
+ "--checkpoints_total_limit",
224
+ type=int,
225
+ default=None,
226
+ help=(
227
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
228
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
229
+ " for more docs"
230
+ ),
231
+ )
232
+ parser.add_argument(
233
+ "--resume_from_checkpoint",
234
+ type=str,
235
+ default=None,
236
+ help=(
237
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
238
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
239
+ ),
240
+ )
241
+ parser.add_argument(
242
+ "--gradient_accumulation_steps",
243
+ type=int,
244
+ default=1,
245
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
246
+ )
247
+ parser.add_argument(
248
+ "--gradient_checkpointing",
249
+ action="store_true",
250
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
251
+ )
252
+ parser.add_argument(
253
+ "--learning_rate",
254
+ type=float,
255
+ default=1e-3,
256
+ help="Initial learning rate (after the potential warmup period) to use.",
257
+ )
258
+ parser.add_argument(
259
+ "--learning_rate_1d",
260
+ type=float,
261
+ default=1e-6,
262
+ help="Initial learning rate (after the potential warmup period) to use for 1-d weights",
263
+ )
264
+ parser.add_argument(
265
+ "--scale_lr",
266
+ action="store_true",
267
+ default=False,
268
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
269
+ )
270
+ parser.add_argument(
271
+ "--lr_scheduler",
272
+ type=str,
273
+ default="constant",
274
+ help=(
275
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
276
+ ' "constant", "constant_with_warmup"]'
277
+ ),
278
+ )
279
+ parser.add_argument(
280
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
281
+ )
282
+ parser.add_argument(
283
+ "--lr_num_cycles",
284
+ type=int,
285
+ default=1,
286
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
287
+ )
288
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
289
+ parser.add_argument(
290
+ "--dataloader_num_workers",
291
+ type=int,
292
+ default=0,
293
+ help=(
294
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
295
+ ),
296
+ )
297
+ parser.add_argument(
298
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
299
+ )
300
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
301
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
302
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
303
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
304
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
305
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
306
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
307
+ parser.add_argument(
308
+ "--hub_model_id",
309
+ type=str,
310
+ default=None,
311
+ help="The name of the repository to keep in sync with the local `output_dir`.",
312
+ )
313
+ parser.add_argument(
314
+ "--logging_dir",
315
+ type=str,
316
+ default="logs",
317
+ help=(
318
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
319
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
320
+ ),
321
+ )
322
+ parser.add_argument(
323
+ "--allow_tf32",
324
+ action="store_true",
325
+ help=(
326
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
327
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
328
+ ),
329
+ )
330
+ parser.add_argument(
331
+ "--report_to",
332
+ type=str,
333
+ default="tensorboard",
334
+ help=(
335
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
336
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
337
+ ),
338
+ )
339
+ parser.add_argument(
340
+ "--mixed_precision",
341
+ type=str,
342
+ default=None,
343
+ choices=["no", "fp16", "bf16"],
344
+ help=(
345
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
346
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
347
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--prior_generation_precision",
352
+ type=str,
353
+ default=None,
354
+ choices=["no", "fp32", "fp16", "bf16"],
355
+ help=(
356
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
357
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
358
+ ),
359
+ )
360
+ parser.add_argument("--prior_generation_scheduler_type", type=str, choices=["ddim", "plms", "lms", "euler", "euler_ancestral", "dpm_solver++"], default="ddim", help="diffusion scheduler type")
361
+ parser.add_argument("--prior_generation_num_inference_steps", type=int, default=50, help="number of sampling steps")
362
+
363
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
364
+ parser.add_argument(
365
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
366
+ )
367
+ parser.add_argument(
368
+ "--enable_token_merging", action="store_true", help="Whether or not to use tomesd on prior generation"
369
+ )
370
+ parser.add_argument(
371
+ "--train_text_encoder",
372
+ action="store_true",
373
+ help="Whether to train spectral shifts of the text encoder. If set, the text encoder should be float32 precision.",
374
+ )
375
+ if input_args is not None:
376
+ args = parser.parse_args(input_args)
377
+ else:
378
+ args = parser.parse_args()
379
+
380
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
381
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
382
+ args.local_rank = env_local_rank
383
+
384
+ if args.with_prior_preservation:
385
+ if args.class_data_dir is None:
386
+ raise ValueError("You must specify a data directory for class images.")
387
+ if args.class_prompt is None:
388
+ raise ValueError("You must specify prompt for class images.")
389
+ else:
390
+ # logger is not available yet
391
+ if args.class_data_dir is not None:
392
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
393
+ if args.class_prompt is not None:
394
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
395
+
396
+ return args
397
+
398
+
399
+ class DreamBoothDataset(Dataset):
400
+ """
401
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
402
+ It pre-processes the images and the tokenizes prompts.
403
+ """
404
+
405
+ def __init__(
406
+ self,
407
+ instance_data_root,
408
+ instance_prompt,
409
+ tokenizer,
410
+ class_data_root=None,
411
+ class_prompt=None,
412
+ class_num=None,
413
+ size=512,
414
+ center_crop=False,
415
+ ):
416
+ self.size = size
417
+ self.center_crop = center_crop
418
+ self.tokenizer = tokenizer
419
+
420
+ self.instance_data_root = Path(instance_data_root)
421
+ if not self.instance_data_root.exists():
422
+ raise ValueError("Instance images root doesn't exists.")
423
+
424
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
425
+ self.num_instance_images = len(self.instance_images_path)
426
+ self.instance_prompt = instance_prompt
427
+ self._length = self.num_instance_images
428
+
429
+ if class_data_root is not None:
430
+ self.class_data_root = Path(class_data_root)
431
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
432
+ self.class_images_path = list(self.class_data_root.iterdir())
433
+ if class_num is not None:
434
+ self.num_class_images = min(len(self.class_images_path), class_num)
435
+ else:
436
+ self.num_class_images = len(self.class_images_path)
437
+ self._length = max(self.num_class_images, self.num_instance_images)
438
+ self.class_prompt = class_prompt
439
+ else:
440
+ self.class_data_root = None
441
+
442
+ self.image_transforms = transforms.Compose(
443
+ [
444
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
445
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
446
+ transforms.ToTensor(),
447
+ transforms.Normalize([0.5], [0.5]),
448
+ ]
449
+ )
450
+
451
+ def __len__(self):
452
+ return self._length
453
+
454
+ def __getitem__(self, index):
455
+ example = {}
456
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
457
+ if not instance_image.mode == "RGB":
458
+ instance_image = instance_image.convert("RGB")
459
+ example["instance_images"] = self.image_transforms(instance_image)
460
+ example["instance_prompt_ids"] = self.tokenizer(
461
+ self.instance_prompt,
462
+ truncation=True,
463
+ padding="max_length",
464
+ max_length=self.tokenizer.model_max_length,
465
+ return_tensors="pt",
466
+ ).input_ids
467
+
468
+ if self.class_data_root:
469
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
470
+ if not class_image.mode == "RGB":
471
+ class_image = class_image.convert("RGB")
472
+ example["class_images"] = self.image_transforms(class_image)
473
+ example["class_prompt_ids"] = self.tokenizer(
474
+ self.class_prompt,
475
+ truncation=True,
476
+ padding="max_length",
477
+ max_length=self.tokenizer.model_max_length,
478
+ return_tensors="pt",
479
+ ).input_ids
480
+
481
+ return example
482
+
483
+
484
+ def collate_fn(examples, with_prior_preservation=False):
485
+ input_ids = [example["instance_prompt_ids"] for example in examples]
486
+ pixel_values = [example["instance_images"] for example in examples]
487
+
488
+ # Concat class and instance examples for prior preservation.
489
+ # We do this to avoid doing two forward passes.
490
+ if with_prior_preservation:
491
+ input_ids += [example["class_prompt_ids"] for example in examples]
492
+ pixel_values += [example["class_images"] for example in examples]
493
+
494
+ pixel_values = torch.stack(pixel_values)
495
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
496
+
497
+ input_ids = torch.cat(input_ids, dim=0)
498
+
499
+ batch = {
500
+ "input_ids": input_ids,
501
+ "pixel_values": pixel_values,
502
+ }
503
+ return batch
504
+
505
+
506
+ class PromptDataset(Dataset):
507
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
508
+
509
+ def __init__(self, prompt, num_samples):
510
+ self.prompt = prompt
511
+ self.num_samples = num_samples
512
+
513
+ def __len__(self):
514
+ return self.num_samples
515
+
516
+ def __getitem__(self, index):
517
+ example = {}
518
+ example["prompt"] = self.prompt
519
+ example["index"] = index
520
+ return example
521
+
522
+
523
+ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
524
+ logger.info(
525
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
526
+ f" {args.validation_prompt}."
527
+ )
528
+ # create pipeline (note: unet and vae are loaded again in float32)
529
+ pipeline = DiffusionPipeline.from_pretrained(
530
+ args.pretrained_model_name_or_path,
531
+ text_encoder=text_encoder,
532
+ tokenizer=tokenizer,
533
+ unet=accelerator.unwrap_model(unet),
534
+ vae=vae,
535
+ revision=args.revision,
536
+ torch_dtype=weight_dtype,
537
+ )
538
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
539
+ pipeline = pipeline.to(accelerator.device)
540
+ pipeline.set_progress_bar_config(disable=True)
541
+
542
+ # run inference
543
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
544
+ images = []
545
+ for _ in range(args.num_validation_images):
546
+ with torch.autocast("cuda"):
547
+ image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
548
+ images.append(image)
549
+
550
+ for tracker in accelerator.trackers:
551
+ if tracker.name == "tensorboard":
552
+ np_images = np.stack([np.asarray(img) for img in images])
553
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
554
+ if tracker.name == "wandb":
555
+ tracker.log(
556
+ {
557
+ "validation": [
558
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
559
+ ]
560
+ }
561
+ )
562
+
563
+ del pipeline
564
+ torch.cuda.empty_cache()
565
+
566
+
567
+
568
+ def main(args):
569
+ logging_dir = Path(args.output_dir, args.logging_dir)
570
+
571
+ accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
572
+
573
+ accelerator = Accelerator(
574
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
575
+ mixed_precision=args.mixed_precision,
576
+ log_with=args.report_to,
577
+ logging_dir=logging_dir,
578
+ project_config=accelerator_project_config,
579
+ )
580
+
581
+ if args.report_to == "wandb":
582
+ if not is_wandb_available():
583
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
584
+ import wandb
585
+
586
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
587
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
588
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
589
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
590
+ raise ValueError(
591
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
592
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
593
+ )
594
+ # Make one log on every process with the configuration for debugging.
595
+ logging.basicConfig(
596
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
597
+ datefmt="%m/%d/%Y %H:%M:%S",
598
+ level=logging.INFO,
599
+ )
600
+ logger.info(accelerator.state, main_process_only=False)
601
+ if accelerator.is_local_main_process:
602
+ transformers.utils.logging.set_verbosity_warning()
603
+ diffusers.utils.logging.set_verbosity_info()
604
+ else:
605
+ transformers.utils.logging.set_verbosity_error()
606
+ diffusers.utils.logging.set_verbosity_error()
607
+
608
+ # If passed along, set the training seed now.
609
+ if args.seed is not None:
610
+ set_seed(args.seed)
611
+
612
+ # Generate class images if prior preservation is enabled.
613
+ if args.with_prior_preservation:
614
+ class_images_dir = Path(args.class_data_dir)
615
+ if not class_images_dir.exists():
616
+ class_images_dir.mkdir(parents=True)
617
+ cur_class_images = len(list(class_images_dir.iterdir()))
618
+
619
+ if cur_class_images < args.num_class_images:
620
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
621
+ if args.prior_generation_precision == "fp32":
622
+ torch_dtype = torch.float32
623
+ elif args.prior_generation_precision == "fp16":
624
+ torch_dtype = torch.float16
625
+ elif args.prior_generation_precision == "bf16":
626
+ torch_dtype = torch.bfloat16
627
+ pipeline = StableDiffusionPipeline.from_pretrained(
628
+ args.pretrained_model_name_or_path,
629
+ vae=AutoencoderKL.from_pretrained(
630
+ args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
631
+ subfolder=None if args.pretrained_vae_name_or_path else "vae",
632
+ revision=None if args.pretrained_vae_name_or_path else args.revision,
633
+ torch_dtype=torch_dtype
634
+ ),
635
+ torch_dtype=torch_dtype,
636
+ safety_checker=None,
637
+ revision=args.revision,
638
+ )
639
+ pipeline.scheduler = SCHEDULER_MAPPING[args.prior_generation_scheduler_type].from_config(pipeline.scheduler.config)
640
+ if is_xformers_available():
641
+ pipeline.enable_xformers_memory_efficient_attention()
642
+ if args.enable_token_merging:
643
+ try:
644
+ import tomesd
645
+ except ImportError:
646
+ raise ImportError(
647
+ "To use token merging (ToMe), please install the tomesd library: `pip install tomesd`."
648
+ )
649
+ tomesd.apply_patch(pipeline, ratio=0.5)
650
+
651
+ pipeline.set_progress_bar_config(disable=True)
652
+
653
+ num_new_images = args.num_class_images - cur_class_images
654
+ logger.info(f"Number of class images to sample: {num_new_images}.")
655
+
656
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
657
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
658
+
659
+ sample_dataloader = accelerator.prepare(sample_dataloader)
660
+ pipeline.to(accelerator.device)
661
+
662
+ for example in tqdm(
663
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
664
+ ):
665
+ images = pipeline(
666
+ example["prompt"],
667
+ num_inference_steps=args.prior_generation_num_inference_steps,
668
+ ).images
669
+
670
+ for i, image in enumerate(images):
671
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
672
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
673
+ image.save(image_filename)
674
+
675
+ del pipeline
676
+ if torch.cuda.is_available():
677
+ torch.cuda.empty_cache()
678
+
679
+ # Handle the repository creation
680
+ if accelerator.is_main_process:
681
+ if args.output_dir is not None:
682
+ os.makedirs(args.output_dir, exist_ok=True)
683
+
684
+ if args.push_to_hub:
685
+ repo_id = create_repo(
686
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
687
+ ).repo_id
688
+
689
+ # Load the tokenizer
690
+ if args.tokenizer_name:
691
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
692
+ elif args.pretrained_model_name_or_path:
693
+ tokenizer = AutoTokenizer.from_pretrained(
694
+ args.pretrained_model_name_or_path,
695
+ subfolder="tokenizer",
696
+ revision=args.revision,
697
+ use_fast=False,
698
+ )
699
+
700
+ # Load scheduler and models
701
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
702
+ if args.train_text_encoder:
703
+ text_encoder = load_text_encoder_for_svdiff(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
704
+ else:
705
+ text_encoder = CLIPTextModel.from_pretrained(
706
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
707
+ )
708
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
709
+ unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=True)
710
+
711
+ # We only train the additional spectral shifts
712
+ vae.requires_grad_(False)
713
+ text_encoder.requires_grad_(False)
714
+ unet.requires_grad_(False)
715
+ optim_params = []
716
+ optim_params_1d = []
717
+ for n, p in unet.named_parameters():
718
+ if "delta" in n:
719
+ p.requires_grad = True
720
+ if "norm" in n:
721
+ optim_params_1d.append(p)
722
+ else:
723
+ optim_params.append(p)
724
+ if args.train_text_encoder:
725
+ for n, p in text_encoder.named_parameters():
726
+ if "delta" in n:
727
+ p.requires_grad = True
728
+ if "norm" in n:
729
+ optim_params_1d.append(p)
730
+ else:
731
+ optim_params.append(p)
732
+
733
+ total_params = sum(p.numel() for p in optim_params)
734
+ print(f"Number of Trainable Parameters: {total_params * 1.e-6:.2f} M")
735
+
736
+ if args.enable_xformers_memory_efficient_attention:
737
+ if is_xformers_available():
738
+ import xformers
739
+
740
+ xformers_version = version.parse(xformers.__version__)
741
+ if xformers_version == version.parse("0.0.16"):
742
+ logger.warn(
743
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
744
+ )
745
+ unet.enable_xformers_memory_efficient_attention()
746
+ else:
747
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
748
+
749
+ if args.gradient_checkpointing:
750
+ unet.enable_gradient_checkpointing()
751
+ if args.train_text_encoder:
752
+ text_encoder.gradient_checkpointing_enable()
753
+
754
+ # Check that all trainable models are in full precision
755
+ low_precision_error_string = (
756
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
757
+ " doing mixed precision training. copy of the weights should still be float32."
758
+ )
759
+
760
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
761
+ raise ValueError(
762
+ f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
763
+ )
764
+
765
+ if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
766
+ raise ValueError(
767
+ f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
768
+ f" {low_precision_error_string}"
769
+ )
770
+
771
+ # Enable TF32 for faster training on Ampere GPUs,
772
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
773
+ if args.allow_tf32:
774
+ torch.backends.cuda.matmul.allow_tf32 = True
775
+
776
+ if args.scale_lr:
777
+ args.learning_rate = (
778
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
779
+ )
780
+
781
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
782
+ if args.use_8bit_adam:
783
+ try:
784
+ import bitsandbytes as bnb
785
+ except ImportError:
786
+ raise ImportError(
787
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
788
+ )
789
+
790
+ optimizer_class = bnb.optim.AdamW8bit
791
+ else:
792
+ optimizer_class = torch.optim.AdamW
793
+
794
+ # Optimizer creation
795
+ optimizer = optimizer_class(
796
+ [{"params": optim_params}, {"params": optim_params_1d, "lr": args.learning_rate_1d}],
797
+ lr=args.learning_rate,
798
+ betas=(args.adam_beta1, args.adam_beta2),
799
+ weight_decay=args.adam_weight_decay,
800
+ eps=args.adam_epsilon,
801
+ )
802
+
803
+ # Dataset and DataLoaders creation:
804
+ train_dataset = DreamBoothDataset(
805
+ instance_data_root=args.instance_data_dir,
806
+ instance_prompt=args.instance_prompt,
807
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
808
+ class_prompt=args.class_prompt,
809
+ class_num=args.num_class_images,
810
+ tokenizer=tokenizer,
811
+ size=args.resolution,
812
+ center_crop=args.center_crop,
813
+ )
814
+
815
+ train_dataloader = torch.utils.data.DataLoader(
816
+ train_dataset,
817
+ batch_size=args.train_batch_size,
818
+ shuffle=True,
819
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
820
+ num_workers=args.dataloader_num_workers,
821
+ )
822
+
823
+ # Scheduler and math around the number of training steps.
824
+ overrode_max_train_steps = False
825
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
826
+ if args.max_train_steps is None:
827
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
828
+ overrode_max_train_steps = True
829
+
830
+ lr_scheduler = get_scheduler(
831
+ args.lr_scheduler,
832
+ optimizer=optimizer,
833
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
834
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
835
+ num_cycles=args.lr_num_cycles,
836
+ power=args.lr_power,
837
+ )
838
+
839
+ # Prepare everything with our `accelerator`.
840
+ if args.train_text_encoder:
841
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
842
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
843
+ )
844
+ else:
845
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
846
+ unet, optimizer, train_dataloader, lr_scheduler
847
+ )
848
+
849
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
850
+ # as these models are only used for inference, keeping weights in full precision is not required.
851
+ weight_dtype = torch.float32
852
+ if accelerator.mixed_precision == "fp16":
853
+ weight_dtype = torch.float16
854
+ elif accelerator.mixed_precision == "bf16":
855
+ weight_dtype = torch.bfloat16
856
+
857
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
858
+ # unet.to(accelerator.device, dtype=weight_dtype)
859
+ vae.to(accelerator.device, dtype=weight_dtype)
860
+ if not args.train_text_encoder:
861
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
862
+
863
+
864
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
865
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
866
+ if overrode_max_train_steps:
867
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
868
+ # Afterwards we recalculate our number of training epochs
869
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
870
+
871
+ # We need to initialize the trackers we use, and also store our configuration.
872
+ # The trackers initializes automatically on the main process.
873
+ if accelerator.is_main_process:
874
+ accelerator.init_trackers("svdiff-pytorch", config=vars(args))
875
+
876
+ # cache keys to save
877
+ state_dict_keys = [k for k in accelerator.unwrap_model(unet).state_dict().keys() if "delta" in k]
878
+ if args.train_text_encoder:
879
+ state_dict_keys_te = [k for k in accelerator.unwrap_model(text_encoder).state_dict().keys() if "delta" in k]
880
+
881
+ def save_weights(step, save_path=None):
882
+ # Create the pipeline using using the trained modules and save it.
883
+ if accelerator.is_main_process:
884
+ if save_path is None:
885
+ save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
886
+ os.makedirs(save_path, exist_ok=True)
887
+ state_dict = accelerator.unwrap_model(unet, keep_fp32_wrapper=True).state_dict()
888
+ # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
889
+ state_dict = {k: state_dict[k] for k in state_dict_keys}
890
+ save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
891
+ if args.train_text_encoder:
892
+ state_dict = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).state_dict()
893
+ # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
894
+ state_dict = {k: state_dict[k] for k in state_dict_keys_te}
895
+ save_file(state_dict, os.path.join(save_path, "spectral_shifts_te.safetensors"))
896
+
897
+ print(f"[*] Weights saved at {save_path}")
898
+
899
+ # Train!
900
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
901
+
902
+ logger.info("***** Running training *****")
903
+ logger.info(f" Num examples = {len(train_dataset)}")
904
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
905
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
906
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
907
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
908
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
909
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
910
+ global_step = 0
911
+ first_epoch = 0
912
+
913
+ # Potentially load in the weights and states from a previous save
914
+ if args.resume_from_checkpoint:
915
+ if args.resume_from_checkpoint != "latest":
916
+ path = os.path.basename(args.resume_from_checkpoint)
917
+ else:
918
+ # Get the mos recent checkpoint
919
+ dirs = os.listdir(args.output_dir)
920
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
921
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
922
+ path = dirs[-1] if len(dirs) > 0 else None
923
+
924
+ if path is None:
925
+ accelerator.print(
926
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
927
+ )
928
+ args.resume_from_checkpoint = None
929
+ else:
930
+ accelerator.print(f"Resuming from checkpoint {path}")
931
+ accelerator.load_state(os.path.join(args.output_dir, path))
932
+ global_step = int(path.split("-")[1])
933
+
934
+ resume_global_step = global_step * args.gradient_accumulation_steps
935
+ first_epoch = global_step // num_update_steps_per_epoch
936
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
937
+
938
+ # Only show the progress bar once on each machine.
939
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
940
+ progress_bar.set_description("Steps")
941
+
942
+ for epoch in range(first_epoch, args.num_train_epochs):
943
+ unet.train()
944
+ if args.train_text_encoder:
945
+ text_encoder.train()
946
+ for step, batch in enumerate(train_dataloader):
947
+ # Skip steps until we reach the resumed step
948
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
949
+ if step % args.gradient_accumulation_steps == 0:
950
+ progress_bar.update(1)
951
+ continue
952
+
953
+ with accelerator.accumulate(unet):
954
+ # Convert images to latent space
955
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
956
+ latents = latents * vae.config.scaling_factor
957
+
958
+ # Sample noise that we'll add to the latents
959
+ noise = torch.randn_like(latents)
960
+ bsz = latents.shape[0]
961
+ # Sample a random timestep for each image
962
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
963
+ timesteps = timesteps.long()
964
+
965
+ # Add noise to the latents according to the noise magnitude at each timestep
966
+ # (this is the forward diffusion process)
967
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
968
+
969
+ # Get the text embedding for conditioning
970
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
971
+
972
+ # Predict the noise residual
973
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
974
+
975
+ # Get the target for loss depending on the prediction type
976
+ if noise_scheduler.config.prediction_type == "epsilon":
977
+ target = noise
978
+ elif noise_scheduler.config.prediction_type == "v_prediction":
979
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
980
+ else:
981
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
982
+
983
+ if args.with_prior_preservation:
984
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
985
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
986
+ target, target_prior = torch.chunk(target, 2, dim=0)
987
+
988
+ # Compute instance loss
989
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
990
+
991
+ # Compute prior loss
992
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
993
+
994
+ # Add the prior loss to the instance loss.
995
+ loss = loss + args.prior_loss_weight * prior_loss
996
+ else:
997
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
998
+
999
+ accelerator.backward(loss)
1000
+ if accelerator.sync_gradients:
1001
+ params_to_clip = (
1002
+ itertools.chain(unet.parameters(), text_encoder.parameters())
1003
+ if args.train_text_encoder
1004
+ else unet.parameters()
1005
+ )
1006
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1007
+ optimizer.step()
1008
+ lr_scheduler.step()
1009
+ optimizer.zero_grad()
1010
+
1011
+ # Checks if the accelerator has performed an optimization step behind the scenes
1012
+ if accelerator.sync_gradients:
1013
+ progress_bar.update(1)
1014
+ global_step += 1
1015
+
1016
+ if global_step % args.checkpointing_steps == 0:
1017
+ if accelerator.is_main_process:
1018
+ save_weights(global_step)
1019
+ # save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1020
+ # accelerator.save_state(save_path)
1021
+ # logger.info(f"Saved state to {save_path}")
1022
+
1023
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "lr_1d": lr_scheduler.get_last_lr()[1]}
1024
+ progress_bar.set_postfix(**logs)
1025
+ accelerator.log(logs, step=global_step)
1026
+
1027
+ if global_step >= args.max_train_steps:
1028
+ break
1029
+
1030
+ if accelerator.is_main_process:
1031
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1032
+ log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
1033
+
1034
+ accelerator.wait_for_everyone()
1035
+ # put the latest checkpoint to output-dir
1036
+ save_weights(global_step, save_path=args.output_dir)
1037
+ if accelerator.is_main_process:
1038
+ if args.push_to_hub:
1039
+ save_model_card(
1040
+ repo_id,
1041
+ base_model=args.pretrained_model_name_or_path,
1042
+ prompt=args.instance_prompt,
1043
+ repo_folder=args.output_dir,
1044
+ )
1045
+ upload_folder(
1046
+ repo_id=repo_id,
1047
+ folder_path=args.output_dir,
1048
+ commit_message="End of training",
1049
+ ignore_patterns=["step_*", "epoch_*"],
1050
+ )
1051
+
1052
+ accelerator.end_training()
1053
+
1054
+
1055
+ if __name__ == "__main__":
1056
+ args = parse_args()
1057
+ main(args)
trainer.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+
10
+ import gradio as gr
11
+ import PIL.Image
12
+ import slugify
13
+ import torch
14
+ from huggingface_hub import HfApi
15
+ from accelerate.utils import write_basic_config
16
+
17
+
18
+ from app_upload import ModelUploader
19
+ from utils import save_model_card
20
+
21
+ URL_TO_JOIN_LIBRARY_ORG = 'https://huggingface.co/organizations/svdiff-library/share/PZBRRkosXikenXUdjMcvcoFmpWjcWnZjKL'
22
+
23
+
24
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
25
+ w, h = image.size
26
+ if w == h:
27
+ return image
28
+ elif w > h:
29
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
30
+ new_image.paste(image, (0, (w - h) // 2))
31
+ return new_image
32
+ else:
33
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
34
+ new_image.paste(image, ((h - w) // 2, 0))
35
+ return new_image
36
+
37
+
38
+ class Trainer:
39
+ def __init__(self, hf_token: str | None = None):
40
+ self.hf_token = hf_token
41
+ self.api = HfApi(token=hf_token)
42
+ self.model_uploader = ModelUploader(hf_token)
43
+
44
+ def prepare_dataset(self, instance_images: list, resolution: int,
45
+ instance_data_dir: pathlib.Path) -> None:
46
+ shutil.rmtree(instance_data_dir, ignore_errors=True)
47
+ instance_data_dir.mkdir(parents=True)
48
+ for i, temp_path in enumerate(instance_images):
49
+ image = PIL.Image.open(temp_path.name)
50
+ image = pad_image(image)
51
+ image = image.resize((resolution, resolution))
52
+ image = image.convert('RGB')
53
+ out_path = instance_data_dir / f'{i:03d}.jpg'
54
+ image.save(out_path, format='JPEG', quality=100)
55
+
56
+ def join_library_org(self) -> None:
57
+ subprocess.run(
58
+ shlex.split(
59
+ f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_LIBRARY_ORG}'
60
+ ))
61
+
62
+ def run(
63
+ self,
64
+ instance_images: list | None,
65
+ instance_prompt: str,
66
+ output_model_name: str,
67
+ overwrite_existing_model: bool,
68
+ validation_prompt: str,
69
+ base_model: str,
70
+ resolution_s: str,
71
+ n_steps: int,
72
+ learning_rate: float,
73
+ gradient_accumulation: int,
74
+ seed: int,
75
+ fp16: bool,
76
+ use_8bit_adam: bool,
77
+ gradient_checkpointing: bool,
78
+ # enable_xformers_memory_efficient_attention: bool,
79
+ checkpointing_steps: int,
80
+ use_wandb: bool,
81
+ validation_epochs: int,
82
+ upload_to_hub: bool,
83
+ use_private_repo: bool,
84
+ delete_existing_repo: bool,
85
+ upload_to: str,
86
+ remove_gpu_after_training: bool,
87
+ ) -> str:
88
+ if not torch.cuda.is_available():
89
+ raise gr.Error('CUDA is not available.')
90
+ if instance_images is None:
91
+ raise gr.Error('You need to upload images.')
92
+ if not instance_prompt:
93
+ raise gr.Error('The instance prompt is missing.')
94
+ if not validation_prompt:
95
+ raise gr.Error('The validation prompt is missing.')
96
+
97
+ resolution = int(resolution_s)
98
+
99
+ if not output_model_name:
100
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
101
+ output_model_name = f'svdiff-pytorch-{timestamp}'
102
+ output_model_name = slugify.slugify(output_model_name)
103
+
104
+ repo_dir = pathlib.Path(__file__).parent
105
+ output_dir = repo_dir / 'experiments' / output_model_name
106
+ if overwrite_existing_model or upload_to_hub:
107
+ shutil.rmtree(output_dir, ignore_errors=True)
108
+ output_dir.mkdir(parents=True)
109
+
110
+ instance_data_dir = repo_dir / 'training_data' / output_model_name
111
+ self.prepare_dataset(instance_images, resolution, instance_data_dir)
112
+
113
+ if upload_to_hub:
114
+ self.join_library_org()
115
+ # accelerate config
116
+ write_basic_config()
117
+ command = f'''
118
+ accelerate launch train_svdiff.py \
119
+ --pretrained_model_name_or_path={base_model} \
120
+ --instance_data_dir={instance_data_dir} \
121
+ --output_dir={output_dir} \
122
+ --instance_prompt="{instance_prompt}" \
123
+ --resolution={resolution} \
124
+ --train_batch_size=1 \
125
+ --gradient_accumulation_steps={gradient_accumulation} \
126
+ --learning_rate={learning_rate} \
127
+ --learning_rate_1d=1e-6 \
128
+ --train_text_encoder \
129
+ --lr_scheduler=constant \
130
+ --lr_warmup_steps=0 \
131
+ --max_train_steps={n_steps} \
132
+ --checkpointing_steps={checkpointing_steps} \
133
+ --validation_prompt="{validation_prompt}" \
134
+ --validation_epochs={validation_epochs} \
135
+ --seed={seed}
136
+ '''
137
+ if fp16:
138
+ command += ' --mixed_precision="fp16"'
139
+ if use_8bit_adam:
140
+ command += ' --use_8bit_adam'
141
+ if gradient_checkpointing:
142
+ command += ' --gradient_checkpointing'
143
+ # if enable_xformers_memory_efficient_attention:
144
+ # command += ' --enable_xformers_memory_efficient_attention'
145
+ if use_wandb:
146
+ command += ' --report_to wandb'
147
+
148
+ with open(output_dir / 'train.sh', 'w') as f:
149
+ command_s = ' '.join(command.split())
150
+ f.write(command_s)
151
+ subprocess.run(shlex.split(command))
152
+ save_model_card(save_dir=output_dir,
153
+ base_model=base_model,
154
+ instance_prompt=instance_prompt,
155
+ test_prompt=validation_prompt,
156
+ test_image_dir='test_images')
157
+
158
+ message = 'Training completed!'
159
+ print(message)
160
+
161
+ if upload_to_hub:
162
+ upload_message = self.model_uploader.upload_model(
163
+ folder_path=output_dir.as_posix(),
164
+ repo_name=output_model_name,
165
+ upload_to=upload_to,
166
+ private=use_private_repo,
167
+ delete_existing_repo=delete_existing_repo)
168
+ print(upload_message)
169
+ message = message + '\n' + upload_message
170
+
171
+ if remove_gpu_after_training:
172
+ space_id = os.getenv('SPACE_ID')
173
+ if space_id:
174
+ self.api.request_space_hardware(repo_id=space_id,
175
+ hardware='cpu-basic')
176
+
177
+ return message
uploader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.api = HfApi(token=hf_token)
9
+
10
+ def get_username(self) -> str:
11
+ return self.api.whoami()['name']
12
+
13
+ def upload(self,
14
+ folder_path: str,
15
+ repo_name: str,
16
+ organization: str = '',
17
+ repo_type: str = 'model',
18
+ private: bool = True,
19
+ delete_existing_repo: bool = False) -> str:
20
+ if not folder_path:
21
+ raise ValueError
22
+ if not repo_name:
23
+ raise ValueError
24
+ if not organization:
25
+ organization = self.get_username()
26
+ repo_id = f'{organization}/{repo_name}'
27
+ if delete_existing_repo:
28
+ try:
29
+ self.api.delete_repo(repo_id, repo_type=repo_type)
30
+ except Exception:
31
+ pass
32
+ try:
33
+ self.api.create_repo(repo_id, repo_type=repo_type, private=private)
34
+ self.api.upload_folder(repo_id=repo_id,
35
+ folder_path=folder_path,
36
+ path_in_repo='.',
37
+ repo_type=repo_type)
38
+ url = f'https://huggingface.co/{repo_id}'
39
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
40
+ except Exception as e:
41
+ message = str(e)
42
+ return message
utils.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+
5
+
6
+ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
7
+ repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
+ if not exp_root_dir.exists():
10
+ return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'spectral_shifts.safetensors').exists()
15
+ ]
16
+ if ignore_repo:
17
+ exp_dirs = [
18
+ exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
19
+ ]
20
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
21
+
22
+
23
+ def save_model_card(
24
+ save_dir: pathlib.Path,
25
+ base_model: str,
26
+ instance_prompt: str,
27
+ test_prompt: str = '',
28
+ test_image_dir: str = '',
29
+ ) -> None:
30
+ image_str = ''
31
+ if test_prompt and test_image_dir:
32
+ image_paths = sorted((save_dir / test_image_dir).glob('*'))
33
+ if image_paths:
34
+ image_str = f'Test prompt: {test_prompt}\n'
35
+ for image_path in image_paths:
36
+ rel_path = image_path.relative_to(save_dir)
37
+ image_str += f'![{image_path.stem}]({rel_path})\n'
38
+
39
+ model_card = f'''---
40
+ license: creativeml-openrail-m
41
+ base_model: {base_model}
42
+ instance_prompt: {instance_prompt}
43
+ tags:
44
+ - stable-diffusion
45
+ - stable-diffusion-diffusers
46
+ - text-to-image
47
+ - diffusers
48
+ - lora
49
+ inference: true
50
+ ---
51
+ # SVDiff-pytorch - {save_dir.name}
52
+
53
+ These are SVDiff weights for {base_model}. The weights were trained on "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
54
+ {image_str}
55
+ '''
56
+
57
+ with open(save_dir / 'README.md', 'w') as f:
58
+ f.write(model_card)