hysts HF staff commited on
Commit
a3a5dce
1 Parent(s): 775da60
Files changed (11) hide show
  1. .gitignore +164 -0
  2. .gitmodules +3 -0
  3. .pre-commit-config.yaml +35 -0
  4. .style.yapf +5 -0
  5. LICENSE +21 -0
  6. app.py +225 -0
  7. inference.py +73 -0
  8. lora +1 -0
  9. requirements.txt +10 -0
  10. style.css +3 -0
  11. trainer.py +109 -0
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training_data/
2
+ results/
3
+
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "lora"]
2
+ path = lora
3
+ url = https://github.com/cloneofsimo/lora
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.10.1
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Unofficial demo app for https://github.com/cloneofsimo/lora.
3
+
4
+ The code in this repo is partly adapted from the following repository:
5
+ https://huggingface.co/spaces/multimodalart/dreambooth-training/tree/a00184917aa273c6d8adab08d5deb9b39b997938
6
+ The license of the original code is MIT, which is specified in the README.md.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import pathlib
13
+
14
+ import gradio as gr
15
+ import torch
16
+
17
+ from inference import InferencePipeline
18
+ from trainer import Trainer
19
+
20
+ TITLE = '# LoRA + StableDiffusion Training UI'
21
+ DESCRIPTION = 'This is an unofficial demo for [https://github.com/cloneofsimo/lora](https://github.com/cloneofsimo/lora).'
22
+
23
+ USAGE_INFO = '''You can train and download models in the "Training" tab, and test them in the "Test" tab.
24
+
25
+ You can also test the pretrained models in the [original repo](https://github.com/cloneofsimo/lora).
26
+ Models with names starting with "lora/" are the pretrained models and the ones with names starting with "results/" are your trained models.
27
+ After training, you can press "Reload Weight List" button to load your trained model names.
28
+
29
+ Note that your trained models will be deleted when the second training is started.
30
+ '''
31
+
32
+ SPACE_ID = os.getenv('SPACE_ID', 'hysts/LoRA-SD-training')
33
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
34
+
35
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
36
+ '''
37
+ CUDA_NOT_AVAILABLE_WARNING = '# Attention - CUDA is not available in this environment.'
38
+
39
+
40
+ def show_warning(warning_text: str) -> gr.Blocks:
41
+ with gr.Blocks() as demo:
42
+ with gr.Box():
43
+ gr.Markdown(warning_text)
44
+ return demo
45
+
46
+
47
+ def update_output_files() -> dict:
48
+ paths = sorted(pathlib.Path('results').glob('*.pt'))
49
+ paths = [path.as_posix() for path in paths] # type: ignore
50
+ return gr.update(value=paths or None)
51
+
52
+
53
+ def create_training_demo(trainer: Trainer,
54
+ pipe: InferencePipeline) -> gr.Blocks:
55
+ with gr.Blocks() as demo:
56
+ base_model = gr.Dropdown(
57
+ choices=['stabilityai/stable-diffusion-2-1-base'],
58
+ value='stabilityai/stable-diffusion-2-1-base',
59
+ label='Base Model',
60
+ visible=False)
61
+ resolution = gr.Dropdown(choices=['512'],
62
+ value='512',
63
+ label='Resolution',
64
+ visible=False)
65
+
66
+ with gr.Row():
67
+ with gr.Box():
68
+ gr.Markdown('Training Data')
69
+ concept_images = gr.Files(label='Images for your concept')
70
+ concept_prompt = gr.Textbox(label='Concept Prompt',
71
+ value='sks',
72
+ max_lines=1)
73
+ gr.Markdown('''
74
+ - Upload images of the style you are planning on training on.
75
+ - For a concept prompt, use a unique, made up word to avoid collisions.
76
+ ''')
77
+ with gr.Box():
78
+ gr.Markdown('Training Parameters')
79
+ num_training_steps = gr.Number(
80
+ label='Number of Training Steps', value=1000, precision=0)
81
+ learning_rate = gr.Number(label='Learning Rate', value=0.0001)
82
+ gr.Markdown('''
83
+ - It will take about 15-20 minutes to train for 1000 steps with a T4 GPU.
84
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
85
+ ''')
86
+
87
+ run_button = gr.Button('Start Training')
88
+ with gr.Box():
89
+ with gr.Row():
90
+ check_status_button = gr.Button('Check Training Status')
91
+ with gr.Column():
92
+ training_status = gr.Markdown()
93
+ output_files = gr.Files(label='Trained Weight Files')
94
+
95
+ run_button.click(fn=pipe.clear)
96
+ run_button.click(fn=trainer.run,
97
+ inputs=[
98
+ base_model,
99
+ resolution,
100
+ concept_images,
101
+ concept_prompt,
102
+ num_training_steps,
103
+ learning_rate,
104
+ ],
105
+ outputs=[
106
+ training_status,
107
+ output_files,
108
+ ],
109
+ queue=False)
110
+ check_status_button.click(fn=trainer.check_if_running,
111
+ inputs=None,
112
+ outputs=training_status,
113
+ queue=False)
114
+ check_status_button.click(fn=update_output_files,
115
+ inputs=None,
116
+ outputs=output_files,
117
+ queue=False)
118
+ return demo
119
+
120
+
121
+ def find_weight_files() -> list[str]:
122
+ curr_dir = pathlib.Path(__file__).parent
123
+ paths = sorted(curr_dir.rglob('*.pt'))
124
+ return [path.relative_to(curr_dir).as_posix() for path in paths]
125
+
126
+
127
+ def reload_lora_weight_list() -> dict:
128
+ return gr.update(choices=find_weight_files())
129
+
130
+
131
+ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
132
+ with gr.Blocks() as demo:
133
+ with gr.Row():
134
+ with gr.Column():
135
+ base_model = gr.Dropdown(
136
+ choices=['stabilityai/stable-diffusion-2-1-base'],
137
+ value='stabilityai/stable-diffusion-2-1-base',
138
+ label='Base Model',
139
+ visible=False)
140
+ reload_button = gr.Button('Reload Weight List')
141
+ lora_weight_name = gr.Dropdown(choices=find_weight_files(),
142
+ value='lora/lora_disney.pt',
143
+ label='LoRA Weight File')
144
+ prompt = gr.Textbox(
145
+ label='Prompt',
146
+ max_lines=1,
147
+ placeholder='Example: "style of sks, baby lion"')
148
+ alpha = gr.Slider(label='Alpha',
149
+ minimum=0,
150
+ maximum=2,
151
+ step=0.05,
152
+ value=1)
153
+ seed = gr.Slider(label='Seed',
154
+ minimum=0,
155
+ maximum=100000,
156
+ step=1,
157
+ value=1)
158
+ with gr.Accordion('Other Parameters', open=False):
159
+ num_steps = gr.Slider(label='Number of Steps',
160
+ minimum=0,
161
+ maximum=100,
162
+ step=1,
163
+ value=50)
164
+ guidance_scale = gr.Slider(label='CFG Scale',
165
+ minimum=0,
166
+ maximum=50,
167
+ step=0.1,
168
+ value=7)
169
+
170
+ run_button = gr.Button('Generate')
171
+ with gr.Column():
172
+ result = gr.Image(label='Result')
173
+
174
+ reload_button.click(fn=reload_lora_weight_list,
175
+ inputs=None,
176
+ outputs=lora_weight_name)
177
+ prompt.submit(fn=pipe.run,
178
+ inputs=[
179
+ base_model,
180
+ lora_weight_name,
181
+ prompt,
182
+ alpha,
183
+ seed,
184
+ num_steps,
185
+ guidance_scale,
186
+ ],
187
+ outputs=result,
188
+ queue=False)
189
+ run_button.click(fn=pipe.run,
190
+ inputs=[
191
+ base_model,
192
+ lora_weight_name,
193
+ prompt,
194
+ alpha,
195
+ seed,
196
+ num_steps,
197
+ guidance_scale,
198
+ ],
199
+ outputs=result,
200
+ queue=False)
201
+ return demo
202
+
203
+
204
+ pipe = InferencePipeline()
205
+ trainer = Trainer()
206
+
207
+ with gr.Blocks(css='style.css') as demo:
208
+ if os.getenv('IS_SHARED_UI'):
209
+ show_warning(SHARED_UI_WARNING)
210
+ if not torch.cuda.is_available():
211
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
212
+
213
+ gr.Markdown(TITLE)
214
+ gr.Markdown(DESCRIPTION)
215
+
216
+ with gr.Tabs():
217
+ with gr.TabItem('Training'):
218
+ create_training_demo(trainer, pipe)
219
+ with gr.TabItem('Test'):
220
+ create_inference_demo(pipe)
221
+
222
+ with gr.Accordion('Usage', open=False):
223
+ gr.Markdown(USAGE_INFO)
224
+
225
+ demo.queue(default_enabled=False).launch(share=False)
inference.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+ import sys
6
+
7
+ import gradio as gr
8
+ import PIL.Image
9
+ import torch
10
+ from diffusers import StableDiffusionPipeline
11
+
12
+ sys.path.insert(0, 'lora')
13
+ from lora_diffusion import monkeypatch_lora, tune_lora_scale
14
+
15
+
16
+ class InferencePipeline:
17
+ def __init__(self):
18
+ self.pipe = None
19
+ self.device = torch.device(
20
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
21
+ self.weight_path = None
22
+
23
+ def clear(self) -> None:
24
+ self.weight_path = None
25
+ del self.pipe
26
+ self.pipe = None
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+ @staticmethod
31
+ def get_lora_weight_path(name: str) -> pathlib.Path:
32
+ curr_dir = pathlib.Path(__file__).parent
33
+ return curr_dir / name
34
+
35
+ def load_pipe(self, model_id: str, lora_filename: str) -> None:
36
+ weight_path = self.get_lora_weight_path(lora_filename)
37
+ if weight_path == self.weight_path:
38
+ return
39
+ self.weight_path = weight_path
40
+ lora_weight = torch.load(self.weight_path, map_location=self.device)
41
+
42
+ if self.device.type == 'cpu':
43
+ pipe = StableDiffusionPipeline.from_pretrained(model_id)
44
+ else:
45
+ pipe = StableDiffusionPipeline.from_pretrained(
46
+ model_id, torch_dtype=torch.float16)
47
+ pipe = pipe.to(self.device)
48
+
49
+ monkeypatch_lora(pipe.unet, lora_weight)
50
+ self.pipe = pipe
51
+
52
+ def run(
53
+ self,
54
+ base_model: str,
55
+ lora_weight_name: str,
56
+ prompt: str,
57
+ alpha: float,
58
+ seed: int,
59
+ n_steps: int,
60
+ guidance_scale: float,
61
+ ) -> PIL.Image.Image:
62
+ if not torch.cuda.is_available():
63
+ raise gr.Error('CUDA is not available.')
64
+
65
+ self.load_pipe(base_model, lora_weight_name)
66
+
67
+ generator = torch.Generator(device=self.device).manual_seed(seed)
68
+ tune_lora_scale(self.pipe.unet, alpha) # type: ignore
69
+ out = self.pipe(prompt,
70
+ num_inference_steps=n_steps,
71
+ guidance_scale=guidance_scale,
72
+ generator=generator) # type: ignore
73
+ return out.images[0]
lora ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit ba349e56e23e92e3b128c7c67ae58d3067540daa
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.35.4
3
+ diffusers==0.10.2
4
+ ftfy==6.1.1
5
+ Pillow==9.3.0
6
+ torch==1.13.0
7
+ torchvision==0.14.0
8
+ transformers==4.25.1
9
+ triton==2.0.0.dev20220701
10
+ xformers==0.0.13
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
trainer.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import shlex
6
+ import shutil
7
+ import subprocess
8
+
9
+ import gradio as gr
10
+ import PIL.Image
11
+ import torch
12
+
13
+ os.environ['PYTHONPATH'] = f'lora:{os.getenv("PYTHONPATH", "")}'
14
+
15
+
16
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
17
+ w, h = image.size
18
+ if w == h:
19
+ return image
20
+ elif w > h:
21
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
22
+ new_image.paste(image, (0, (w - h) // 2))
23
+ return new_image
24
+ else:
25
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
26
+ new_image.paste(image, ((h - w) // 2, 0))
27
+ return new_image
28
+
29
+
30
+ class Trainer:
31
+ def __init__(self):
32
+ self.is_running = False
33
+ self.is_running_message = 'Another training is in progress.'
34
+
35
+ self.instance_data_dir = pathlib.Path('training_data')
36
+ self.output_dir = pathlib.Path('results')
37
+
38
+ def check_if_running(self) -> dict:
39
+ if self.is_running:
40
+ return gr.update(value=self.is_running_message)
41
+ else:
42
+ return gr.update(value='No training is running.')
43
+
44
+ def cleanup_dirs(self) -> None:
45
+ shutil.rmtree(self.instance_data_dir, ignore_errors=True)
46
+ shutil.rmtree(self.output_dir, ignore_errors=True)
47
+
48
+ def prepare_dataset(self, concept_images: list, resolution: int) -> None:
49
+ self.instance_data_dir.mkdir()
50
+ for i, temp_path in enumerate(concept_images):
51
+ image = PIL.Image.open(temp_path.name)
52
+ image = pad_image(image)
53
+ image = image.resize((resolution, resolution))
54
+ image = image.convert('RGB')
55
+ out_path = self.instance_data_dir / f'{i:03d}.jpg'
56
+ image.save(out_path, format='JPEG', quality=100)
57
+
58
+ def run(
59
+ self,
60
+ base_model: str,
61
+ resolution_s: str,
62
+ concept_images: list | None,
63
+ concept_prompt: str,
64
+ n_steps: int,
65
+ learning_rate: float,
66
+ ) -> tuple[dict, str]:
67
+ if not torch.cuda.is_available():
68
+ raise gr.Error('CUDA is not available.')
69
+
70
+ out_path = ''
71
+ if self.is_running:
72
+ return gr.update(value=self.is_running_message), out_path
73
+
74
+ if concept_images is None:
75
+ raise gr.Error('You need to upload images.')
76
+ if not concept_prompt:
77
+ raise gr.Error('The concept prompt is missing.')
78
+
79
+ resolution = int(resolution_s)
80
+
81
+ self.cleanup_dirs()
82
+ self.prepare_dataset(concept_images, resolution)
83
+
84
+ self.is_running = True
85
+ command = f'''
86
+ accelerate launch lora/train_lora_dreambooth.py \
87
+ --pretrained_model_name_or_path={base_model} \
88
+ --instance_data_dir={self.instance_data_dir} \
89
+ --output_dir={self.output_dir} \
90
+ --instance_prompt="style of {concept_prompt}" \
91
+ --resolution={resolution} \
92
+ --train_batch_size=1 \
93
+ --gradient_accumulation_steps=1 \
94
+ --learning_rate={learning_rate} \
95
+ --lr_scheduler=constant \
96
+ --lr_warmup_steps=0 \
97
+ --max_train_steps={n_steps}
98
+ '''
99
+ res = subprocess.run(shlex.split(command))
100
+ self.is_running = False
101
+
102
+ if res.returncode == 0:
103
+ result_message = 'Training Completed!'
104
+ weight_path = self.output_dir / 'lora_weight.pt'
105
+ if weight_path.exists():
106
+ out_path = weight_path.as_posix()
107
+ else:
108
+ result_message = 'Training Failed!'
109
+ return gr.update(value=result_message), out_path