Spaces · Runtime error
1lint committed
Commit · 6230dda · 0 Parent(s)
init commit
Browse files
- .gitignore +173 -0
- LICENSE +201 -0
- README.md +76 -0
- app.py +4 -0
- configs/controlnet_config.json +41 -0
- convert_state_dict.sh +8 -0
- main.py +49 -0
- quickstart_train.py +50 -0
- requirements.txt +18 -0
- src/__init__.py +2 -0
- src/app.py +260 -0
- src/controlnet_pipe.py +309 -0
- src/convert_sd.py +223 -0
- src/data.py +149 -0
- src/lab.py +474 -0
- src/ui_assets/controlnet_ids.txt +4 -0
- src/ui_assets/examples +1 -0
- src/ui_assets/footer.html +9 -0
- src/ui_assets/header.html +23 -0
- src/ui_assets/model_ids.txt +5 -0
- src/ui_functions.py +285 -0
- src/ui_shared.py +24 -0
.gitignore
ADDED
@@ -0,0 +1,173 @@
# added
archive/
wandb/
logs
models
.git_*
test*
video/
train.py
deploy.py
examples/
notebooks/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
#   .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2023] [1lint]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
ADDED
@@ -0,0 +1,76 @@
---
title: Style ControlNet
emoji: ❅
colorFrom: gray
colorTo: green
sdk: gradio
sdk_version: 3.30.0
app_file: app.py
pinned: True
license: openrail
---

# ControlStyle

Proof of concept for controlling Stable Diffusion image style with a ControlNet.

| ![](./examples/blue_eyes.gif) | ![](./examples/blue_eyes.png) |
| ------------- | ------------- |

`prompt`: "beautiful woman with blue eyes", `controlnet_prompt`: "1girl, blue eyes"

| ![](./examples/mountains.gif) | ![](./examples/mountains.png) |
| ------------- | ------------- |

`prompt` and `controlnet_prompt`: "best quality, masterpiece, Dark hair, dark eyes, upper body, sun flare, outdoors, mountain, valley, sky. clouds, smiling"

`controlnet_conditioning_scale` increments by 0.1 from 0 to 1, left to right.


## Try Style ControlNet with the A1111 WebUI

![](./examples/zerohint_grid.png)
![](./examples/hint_grid.png)

### Quick start: download the anime controlnets [here](https://huggingface.co/lint/anime_control/tree/main)

The root folder holds controlnets in Diffusers format; `A1111_weights` holds controlnets for use with the [A1111 WebUI ControlNet extension](https://github.com/Mikubill/sd-webui-controlnet). More details are at the [HF repo page](https://huggingface.co/lint/anime_control).

## Quick Start Training

For a basic training example with HF Accelerate, run the following:
```
pip install -r requirements.txt
python quickstart_train.py
```
By default, the script downloads pipeline weights and an image dataset from the HF Hub.
The base Stable Diffusion checkpoint and controlnet weights can be in either HF Diffusers format or the original Stable Diffusion PyTorch Lightning format (inferred from whether the destination path is a file or not).

Use `convert_state_dict.sh` to convert a trained controlnet state dict from `diffusers` format to one compatible with the [A1111 ControlNet extension](https://github.com/Mikubill/sd-webui-controlnet).

## Style ControlNet Web UI

Launch the Web UI locally with
```
python app.py
```

(My HF Spaces below are currently out of date; I will fix them once I have time.)

Try the WebUI hosted on HF Spaces at https://huggingface.co/spaces/lint/anime_controlnet
![](./examples/controlstyle_ui.png)


The WebUI also supports basic training
![](./examples/training_ui.png)


## ControlNet for Style

Lvmin introduced [ControlNet](https://github.com/lllyasviel/ControlNet), which clones the Stable Diffusion UNet to inject external conditioning, such as body poses or sketch lines, into the generation process, with fantastic results.

I thought this approach might also work for introducing a different style (e.g. adding anime style) to guide image generation. Unlike the original controlnets, I initialized the controlnet weights from a distinct UNet (`andite/anything-v4.5`) and predominantly trained it without any controlnet conditioning image, on a synthetic anime dataset (`lint/anybooru`) distinct from the base model's training data. Then the main controlnet weights were frozen, the input hint block weights were added back in, and those were trained on the same dataset using canny edge processing to generate the controlnet conditioning image. A minimal sketch of the initialization step follows this README.

I originally trained the anime style controlnets without any controlnet conditioning image, so that the controlnet would focus on adding anime style rather than structure to the image. Those weights are saved at https://huggingface.co/lint/anime_styler/tree/main/A1111_webui_weights; however, they need to be used with my [fork](https://github.com/1lint/sd-webui-controlnet) of the controlnet extension, which makes very minor changes to allow loading the controlnet without the input hint block weights and passing None as a valid controlnet "conditioning".

Recently I added the input hint processing module back in and trained only the controlnet input hint blocks on canny image generation. So the models in this repository are now just like regular controlnets, except for their different initialization and training process. They can be used like any regular controlnet, but the vast majority of the weights were trained on adding anime style, with only the input hint blocks trained to use the controlnet conditioning image. Although it seems to work alright in my limited testing so far, expect the canny image guidance to be weak, so combine it with the original canny controlnet as needed.

Since the main controlnet weights were trained without any canny image conditioning, they can be (and were intended to be) used without any controlnet conditioning image. The existing A1111 ControlNet extension, however, expects the user to always pass a controlnet conditioning image and otherwise raises an error. You can pass a black square as the "conditioning image"; this adds some unexpected random noise due to the input hint block `bias` weights, but the noise is small enough that the controlnet still appears to "work".
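The initialization described above can be reproduced with stock `diffusers` utilities. The sketch below is illustrative rather than the repo's exact training path (which goes through `src/lab.py`); it assumes `diffusers` as pinned in `requirements.txt` and the `andite/anything-v4.5` repo id mentioned in the README.

```python
# Sketch: initialize a "style" controlnet from a distinct UNet, as described above.
# Illustrative only; the repo's actual training loop lives in src/lab.py.
from diffusers import ControlNetModel, UNet2DConditionModel

# Clone the anime UNet's weights into a ControlNet, instead of cloning the base model's UNet
unet = UNet2DConditionModel.from_pretrained("andite/anything-v4.5", subfolder="unet")
controlnet = ControlNetModel.from_unet(unet)

controlnet.save_pretrained("./models/style_controlnet_init")
```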
app.py
ADDED
@@ -0,0 +1,4 @@
from src import demo
from multiprocessing import cpu_count

demo.queue(concurrency_count=cpu_count()).launch()
configs/controlnet_config.json
ADDED
@@ -0,0 +1,41 @@
{
  "_class_name": "ControlNetModel",
  "_diffusers_version": "0.14.0.dev0",
  "act_fn": "silu",
  "attention_head_dim": 8,
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "class_embed_type": null,
  "conditioning_embedding_out_channels": [
    16,
    32,
    96,
    256
  ],
  "controlnet_conditioning_channel_order": "rgb",
  "cross_attention_dim": 768,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 4,
  "layers_per_block": 2,
  "mid_block_scale_factor": 1,
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "projection_class_embeddings_input_dim": null,
  "resnet_time_scale_shift": "default",
  "upcast_attention": false,
  "use_linear_projection": false
}
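As a quick sanity check, the config above can be used to instantiate a randomly initialized ControlNet. A minimal sketch, assuming `diffusers` is installed as pinned in `requirements.txt` and the file lives at `configs/controlnet_config.json`:

```python
# Sketch: build an untrained ControlNetModel from the JSON config above.
from diffusers import ControlNetModel

config = ControlNetModel.load_config("configs/controlnet_config.json")
controlnet = ControlNetModel.from_config(config)  # random weights, SD 1.x-sized
print(f"{sum(p.numel() for p in controlnet.parameters()) / 1e6:.1f}M parameters")
```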
convert_state_dict.sh
ADDED
@@ -0,0 +1,8 @@
# converts controlnet state dict saved in diffusers format to original stable diffusion controlnet format that can be used with the A1111 controlnet extension

export INPUT_PATH="/home/user/style_controlnet/models/deliberate_v2_animestyler/checkpoint-332228/diffusion_pytorch_model.safetensors"

export OUTPUT_PATH="models/A1111_weights/anime_styler-deliberate-v0.1.safetensors"

python src/convert_sd.py --model_path="$INPUT_PATH" --checkpoint_path="$OUTPUT_PATH" --is_controlnet --half --to_safetensors
main.py
ADDED
@@ -0,0 +1,49 @@
from argparse import Namespace
from multiprocessing import cpu_count
from src.lab import Lab

args = Namespace(

    pretrained_model_name_or_path="lint/liquidfix",
    controlnet_weights_path="lint/anime_control/anime_merge",
    #controlnet_weights_path=None,
    vae_path="lint/anime_vae",

    # dataset args
    train_data_dir="/mnt/g/data/anybooru/train",
    valid_data_dir="/mnt/g/data/anybooru/valid",
    resolution=512,
    from_hf_hub=False,
    controlnet_hint_key="canny",  # set this to "canny" to train with canny hint, or None to pass

    # training args
    # options are ["zero convolutions", "input hint blocks"], otherwise trains whole controlnet
    training_stage="",
    learning_rate=5e-6,
    num_train_epochs=1000,
    max_train_steps=None,
    seed=3434554,
    max_grad_norm=1.0,
    gradient_accumulation_steps=1,

    # VRAM args
    batch_size=1,
    mixed_precision="fp16",  # set to "fp16" for mixed-precision training.
    gradient_checkpointing=True,  # set this to True to lower the memory usage.
    use_8bit_adam=True,  # use 8bit optimizer from bitsandbytes
    enable_xformers_memory_efficient_attention=True,
    allow_tf32=True,
    dataloader_num_workers=cpu_count(),

    # logging args
    output_dir="./models",
    report_to='tensorboard',
    image_logging_steps=600,  # disabled when 0. costs additional VRAM to log images
    save_whole_pipeline=True,
    checkpointing_steps=6000,
)

if __name__ == '__main__':
    lab = Lab(args)
    lab.train(args.num_train_epochs)
quickstart_train.py
ADDED
@@ -0,0 +1,50 @@
from argparse import Namespace
from multiprocessing import cpu_count
from src.lab import Lab

# runs on 10GB VRAM GPU (RTX 3080)
args = Namespace(

    pretrained_model_name_or_path="lint/liquidfix",
    controlnet_weights_path="lint/anime_control/anime_merge",
    #controlnet_weights_path=None,
    vae_path="lint/anime_vae",

    # dataset args
    train_data_dir="lint/anybooru",
    valid_data_dir="",
    resolution=512,
    from_hf_hub=True,
    controlnet_hint_key="canny",  # set this to "canny" to train with canny hint, or None to pass

    # training args
    # options are ["zero convolutions", "input hint blocks"], otherwise trains whole controlnet
    training_stage="",
    learning_rate=5e-6,
    num_train_epochs=1000,
    max_train_steps=None,
    seed=3434554,
    max_grad_norm=1.0,
    gradient_accumulation_steps=1,

    # VRAM args
    batch_size=1,
    mixed_precision="fp16",  # set to "fp16" for mixed-precision training.
    gradient_checkpointing=True,  # set this to True to lower the memory usage.
    use_8bit_adam=True,  # use 8bit optimizer from bitsandbytes
    enable_xformers_memory_efficient_attention=True,
    allow_tf32=True,
    dataloader_num_workers=cpu_count(),

    # logging args
    output_dir="./models",
    report_to='tensorboard',
    image_logging_steps=600,  # disabled when 0. costs additional VRAM to log images
    save_whole_pipeline=True,
    checkpointing_steps=6000,
)

if __name__ == '__main__':
    lab = Lab(args)
    lab.train(args.num_train_epochs)
requirements.txt
ADDED
@@ -0,0 +1,18 @@
accelerate==0.18.0
datasets>=2.10.0
diffusers==0.16.1
gradio>=3.28.3
huggingface_hub>=0.14.1
numpy
packaging
Pillow
torch
torchvision
tqdm
transformers>=4.25.1
omegaconf>=2.2.3
opencv_contrib_python==4.6.0.66
safetensors>=0.2.6
xformers==0.0.17.dev466
bitsandbytes
tensorboard>=2.12.0
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .app import demo
from .lab import Lab
src/app.py
ADDED
@@ -0,0 +1,260 @@
import gradio as gr
from multiprocessing import cpu_count

from src.ui_shared import (
    model_ids,
    scheduler_names,
    default_scheduler,
    controlnet_ids,
    assets_directory,
)

from src.ui_functions import generate, run_training

default_img_size = 512

with open(f"{assets_directory}/header.html") as fp:
    header = fp.read()

with open(f"{assets_directory}/footer.html") as fp:
    footer = fp.read()


theme = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
)

from gradio.themes.builder_app import css

with gr.Blocks(theme=theme) as demo:

    gr.HTML(header)

    with gr.Row():
        with gr.Column(scale=70):
            prompt = gr.Textbox(
                label="Prompt", placeholder="Press <Shift+Enter> to generate", lines=2
            )
            neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="", lines=2)

            with gr.Row():
                controlnet_prompt = gr.Textbox(
                    label="Controlnet Prompt",
                    placeholder="If empty, defaults to base `Prompt`",
                    lines=2,
                )

                controlnet_negative_prompt = gr.Textbox(
                    label="Controlnet Negative Prompt",
                    placeholder="If empty, defaults to base `Negative Prompt`",
                    lines=2,
                )

        with gr.Column(scale=30):
            model_name = gr.Dropdown(
                label="Model", choices=model_ids, value=model_ids[0]
            )
            controlnet_name = gr.Dropdown(
                label="Controlnet", choices=controlnet_ids, value=controlnet_ids[0]
            )
            scheduler_name = gr.Dropdown(
                label="Scheduler", choices=scheduler_names, value=default_scheduler
            )
            generate_button = gr.Button(value="Generate", variant="primary")

    with gr.Row():
        with gr.Column():
            with gr.Tab("Inference") as tab:

                guidance_image = gr.Image(
                    label="Guidance Image",
                    source="upload",
                    tool="editor",
                    type="pil",
                ).style(height=256)

                with gr.Row():
                    controlnet_cond_scale = gr.Slider(
                        label="Controlnet Weight",
                        value=0.5,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                    )

                with gr.Row():
                    batch_size = gr.Slider(
                        label="Batch Size", value=1, minimum=1, maximum=8, step=1
                    )
                    seed = gr.Slider(-1, 2147483647, label="Seed", value=-1, step=1)

                with gr.Row():
                    guidance = gr.Slider(
                        label="Guidance scale", value=7.5, minimum=0, maximum=20
                    )
                    steps = gr.Slider(
                        label="Steps", value=20, minimum=1, maximum=100, step=1
                    )

                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )
                    height = gr.Slider(
                        label="Height",
                        value=default_img_size,
                        minimum=64,
                        maximum=1024,
                        step=32,
                    )


            with gr.Tab("Train Style ControlNet") as tab:
                with gr.Row():
                    train_batch_size = gr.Slider(
                        label="Training Batch Size",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                    )

                    gradient_accumulation_steps = gr.Slider(
                        label="Gradient Accumulation steps",
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=4,
                    )

                with gr.Row():
                    max_train_steps = gr.Number(
                        label="Total training steps", value=16000
                    )
                    train_learning_rate = gr.Number(label="Learning Rate", value=5.0e-6)

                with gr.Row():
                    checkpointing_steps = gr.Number(
                        label="Steps between saving checkpoints", value=4000
                    )
                    image_logging_steps = gr.Number(
                        label="Steps between logging example images (pass 0 to disable)",
                        value=0,
                    )

                with gr.Row():
                    train_data_dir = gr.Textbox(
                        label=f"Path to training image folder",
                        value="lint/anybooru",
                    )
                    valid_data_dir = gr.Textbox(
                        label=f"Path to validation image folder",
                        value="",
                    )

                with gr.Row():
                    controlnet_weights_path = gr.Textbox(
                        label=f"Repo for initializing Controlnet Weights",
                        value="andite/anything-v4.0/unet",
                    )
                    output_dir = gr.Textbox(
                        label=f"Output directory for trained weights", value="./models"
                    )

                with gr.Row():
                    train_whole_controlnet = gr.Checkbox(
                        label="Train whole controlnet", value=True
                    )
                    save_whole_pipeline = gr.Checkbox(
                        label="Save whole pipeline", value=True
                    )

                training_button = gr.Button(
                    value="Train Style ControlNet", variant="primary"
                )

                training_status = gr.Text(label="Training Status")


        with gr.Column():
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(height=default_img_size, grid=2)

            generation_details = gr.Markdown()

            # pipe_kwargs = gr.Textbox(label="Pipe kwargs", value="{\n\t\n}", visible=False)

            # if torch.cuda.is_available():
            #     giga = 2**30
            #     vram_guage = gr.Slider(0, torch.cuda.memory_reserved(0)/giga, label='VRAM Allocated to Reserved (GB)', value=0, step=1)
            #     demo.load(lambda : torch.cuda.memory_allocated(0)/giga, inputs=[], outputs=vram_guage, every=0.5, show_progress=False)

    # gr.HTML(footer)

    inputs = [
        model_name,
        guidance_image,
        controlnet_name,
        scheduler_name,
        prompt,
        guidance,
        steps,
        batch_size,
        width,
        height,
        seed,
        neg_prompt,
        controlnet_prompt,
        controlnet_negative_prompt,
        controlnet_cond_scale,
        # pipe_kwargs,
    ]
    outputs = [gallery, generation_details]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    generate_button.click(generate, inputs=inputs, outputs=outputs)

    training_inputs = [
        model_name,
        controlnet_weights_path,
        train_data_dir,
        valid_data_dir,
        train_batch_size,
        train_whole_controlnet,
        gradient_accumulation_steps,
        max_train_steps,
        train_learning_rate,
        output_dir,
        checkpointing_steps,
        image_logging_steps,
        save_whole_pipeline,
    ]

    training_button.click(
        run_training,
        inputs=training_inputs,
        outputs=[training_status],
    )

    # from gradio.themes.builder_app
    demo.load(
        None,
        None,
        None,
        _js="""() => {
            if (document.querySelectorAll('.dark').length) {
                document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
            } else {
                document.querySelector('body').classList.add('dark');
            }
        }""",
    )

if __name__ == "__main__":
    demo.queue(concurrency_count=cpu_count()).launch()
src/controlnet_pipe.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
|
2 |
+
|
3 |
+
class ControlNetPipe(StableDiffusionControlNetPipeline):
|
4 |
+
|
5 |
+
# copied from superclass and modified to accept controlnet prompt independent of base prompt
|
6 |
+
@torch.no_grad()
|
7 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
8 |
+
def __call__(
|
9 |
+
self,
|
10 |
+
prompt: Union[str, List[str]] = None,
|
11 |
+
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
|
12 |
+
height: Optional[int] = None,
|
13 |
+
width: Optional[int] = None,
|
14 |
+
num_inference_steps: int = 50,
|
15 |
+
guidance_scale: float = 7.5,
|
16 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
17 |
+
num_images_per_prompt: Optional[int] = 1,
|
18 |
+
eta: float = 0.0,
|
19 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
20 |
+
latents: Optional[torch.FloatTensor] = None,
|
21 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
22 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
23 |
+
output_type: Optional[str] = "pil",
|
24 |
+
return_dict: bool = True,
|
25 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
26 |
+
callback_steps: int = 1,
|
27 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
28 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
29 |
+
guess_mode: bool = False,
|
30 |
+
controlnet_prompt_embeds = None,
|
31 |
+
):
|
32 |
+
r"""
|
33 |
+
Function invoked when calling the pipeline for generation.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
prompt (`str` or `List[str]`, *optional*):
|
37 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
38 |
+
instead.
|
39 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
|
40 |
+
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
|
41 |
+
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
42 |
+
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
|
43 |
+
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
|
44 |
+
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
|
45 |
+
specified in init, images must be passed as a list such that each element of the list can be correctly
|
46 |
+
batched for input to a single controlnet.
|
47 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
48 |
+
The height in pixels of the generated image.
|
49 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
50 |
+
The width in pixels of the generated image.
|
51 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
52 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
53 |
+
expense of slower inference.
|
54 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
55 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
56 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
57 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
58 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
59 |
+
usually at the expense of lower image quality.
|
60 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
61 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
62 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
63 |
+
less than `1`).
|
64 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
65 |
+
The number of images to generate per prompt.
|
66 |
+
eta (`float`, *optional*, defaults to 0.0):
|
67 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
68 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
69 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
70 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
71 |
+
to make generation deterministic.
|
72 |
+
latents (`torch.FloatTensor`, *optional*):
|
73 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
74 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
75 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
76 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
77 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
78 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
79 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
80 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
81 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
82 |
+
argument.
|
83 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
84 |
+
The output format of the generate image. Choose between
|
85 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
86 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
87 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
88 |
+
plain tuple.
|
89 |
+
callback (`Callable`, *optional*):
|
90 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
91 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
92 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
93 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
94 |
+
called at every step.
|
95 |
+
cross_attention_kwargs (`dict`, *optional*):
|
96 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
97 |
+
`self.processor` in
|
98 |
+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
99 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
100 |
+
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
101 |
+
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
|
102 |
+
corresponding scale as a list.
|
103 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
104 |
+
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
|
105 |
+
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
|
106 |
+
|
107 |
+
Examples:
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
111 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
112 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
113 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
114 |
+
(nsfw) content, according to the `safety_checker`.
|
115 |
+
"""
|
116 |
+
# 0. Default height and width to unet
|
117 |
+
height, width = self._default_height_width(height, width, image)
|
118 |
+
|
119 |
+
# 1. Check inputs. Raise error if not correct
|
120 |
+
self.check_inputs(
|
121 |
+
prompt,
|
122 |
+
image,
|
123 |
+
height,
|
124 |
+
width,
|
125 |
+
callback_steps,
|
126 |
+
negative_prompt,
|
127 |
+
prompt_embeds,
|
128 |
+
negative_prompt_embeds,
|
129 |
+
controlnet_conditioning_scale,
|
130 |
+
)
|
131 |
+
|
132 |
+
# 2. Define call parameters
|
133 |
+
if prompt is not None and isinstance(prompt, str):
|
134 |
+
batch_size = 1
|
135 |
+
elif prompt is not None and isinstance(prompt, list):
|
136 |
+
batch_size = len(prompt)
|
137 |
+
else:
|
138 |
+
batch_size = prompt_embeds.shape[0]
|
139 |
+
|
140 |
+
device = self._execution_device
|
141 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
142 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
143 |
+
# corresponds to doing no classifier free guidance.
|
144 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
145 |
+
|
146 |
+
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
147 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
|
148 |
+
|
149 |
+
# 3. Encode input prompt
|
150 |
+
prompt_embeds = self._encode_prompt(
|
151 |
+
prompt,
|
152 |
+
device,
|
153 |
+
num_images_per_prompt,
|
154 |
+
do_classifier_free_guidance,
|
155 |
+
negative_prompt,
|
156 |
+
prompt_embeds=prompt_embeds,
|
157 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
158 |
+
)
|
159 |
+
|
160 |
+
# 4. Prepare image
|
161 |
+
if isinstance(self.controlnet, ControlNetModel):
|
162 |
+
image = self.prepare_image(
|
163 |
+
image=image,
|
164 |
+
width=width,
|
165 |
+
height=height,
|
166 |
+
batch_size=batch_size * num_images_per_prompt,
|
167 |
+
num_images_per_prompt=num_images_per_prompt,
|
168 |
+
device=device,
|
169 |
+
dtype=self.controlnet.dtype,
|
170 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
171 |
+
guess_mode=guess_mode,
|
172 |
+
)
|
173 |
+
elif isinstance(self.controlnet, MultiControlNetModel):
|
174 |
+
images = []
|
175 |
+
|
176 |
+
for image_ in image:
|
177 |
+
image_ = self.prepare_image(
|
178 |
+
image=image_,
|
179 |
+
width=width,
|
180 |
+
height=height,
|
181 |
+
batch_size=batch_size * num_images_per_prompt,
|
182 |
+
num_images_per_prompt=num_images_per_prompt,
|
183 |
+
device=device,
|
184 |
+
dtype=self.controlnet.dtype,
|
185 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
186 |
+
guess_mode=guess_mode,
|
187 |
+
)
|
188 |
+
|
189 |
+
images.append(image_)
|
190 |
+
|
191 |
+
image = images
|
192 |
+
else:
|
193 |
+
assert False
|
194 |
+
|
195 |
+
# 5. Prepare timesteps
|
196 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
197 |
+
timesteps = self.scheduler.timesteps
|
198 |
+
|
199 |
+
# 6. Prepare latent variables
|
200 |
+
num_channels_latents = self.unet.config.in_channels
|
201 |
+
latents = self.prepare_latents(
|
202 |
+
batch_size * num_images_per_prompt,
|
203 |
+
num_channels_latents,
|
204 |
+
height,
|
205 |
+
width,
|
206 |
+
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # fall back to the base prompt embeddings for the controlnet branch
        if controlnet_prompt_embeds is None:
            controlnet_prompt_embeds = prompt_embeds

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
                if guess_mode and do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    controlnet_latent_model_input = latents
                    controlnet_prompt_embeds = controlnet_prompt_embeds.chunk(2)[1]
                else:
                    controlnet_latent_model_input = latent_model_input

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    controlnet_latent_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
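For readers skimming the denoising loop above, the classifier-free guidance combine can be checked in isolation. A minimal sketch; the tensor shapes and guidance scale are illustrative assumptions, not values from this file:

# Standalone illustration of the guidance step used in the loop above.
# Shapes and guidance_scale are placeholder assumptions.
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 64, 64)  # [uncond, text] predictions stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 4, 64, 64)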
src/convert_sd.py
ADDED
@@ -0,0 +1,223 @@
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.

# Originally written by jachiam at https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
# modified by 1lint to support controlnet conversion

import argparse
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from pathlib import Path

# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict, is_controlnet=True):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}

    conversion_map = unet_conversion_map
    if is_controlnet:
        # remove output blocks from conversion mapping since controlnet doesn't have them
        conversion_map = unet_conversion_map[:6]

        for k, v in mapping.items():
            # convert controlnet zero convolution keys
            if "controlnet_down_blocks" in v:
                new_key = v.replace("controlnet_down_blocks", "zero_convs")
                new_key = ".0.".join(new_key.rsplit(".", 1))
                mapping[k] = new_key

        mapping["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
        mapping["controlnet_mid_block.weight"] = "middle_block_out.0.weight"

        if "controlnet_cond_embedding.conv_in.weight" in mapping:
            mapping[
                "controlnet_cond_embedding.conv_in.weight"
            ] = "input_hint_block.0.weight"
            mapping[
                "controlnet_cond_embedding.conv_in.bias"
            ] = "input_hint_block.0.bias"

            for i in range(6):
                mapping[
                    f"controlnet_cond_embedding.blocks.{i}.weight"
                ] = f"input_hint_block.{2*(i+1)}.weight"
                mapping[
                    f"controlnet_cond_embedding.blocks.{i}.bias"
                ] = f"input_hint_block.{2*(i+1)}.bias"

            mapping[
                "controlnet_cond_embedding.conv_out.weight"
            ] = "input_hint_block.14.weight"
            mapping[
                "controlnet_cond_embedding.conv_out.bias"
            ] = "input_hint_block.14.bias"

    for sd_name, hf_name in conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v

    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


def load_state_dict(state_dict_path):
    file_ext = state_dict_path.rsplit(".", 1)[-1]

    if file_ext == "safetensors":
        state_dict = {}
        with safe_open(state_dict_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
    else:
        state_dict = torch.load(state_dict_path, map_location="cpu")

    return state_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_path",
        default=None,
        type=str,
        required=True,
        help="Path to the model to convert.",
    )
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output model.",
    )
    parser.add_argument(
        "--half", action="store_true", help="Save weights in half precision."
    )
    parser.add_argument(
        "--is_controlnet",
        action="store_true",
        help="Whether conversion is for controlnet or standard sd unet",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to save state dict in safetensors format",
    )

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_state_dict = load_state_dict(args.model_path)

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(
        unet_state_dict, is_controlnet=args.is_controlnet
    )

    if args.half:
        unet_state_dict = {k: v.half() for k, v in unet_state_dict.items()}

    Path(args.checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

    if args.to_safetensors:
        save_file(unet_state_dict, args.checkpoint_path)
    else:
        torch.save(unet_state_dict, args.checkpoint_path)

    print(
        f"Converted {Path(args.model_path)} to original SD format at {Path(args.checkpoint_path)}"
    )
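A hedged usage sketch for the converter above, driven as a module rather than from the command line; the input and output paths are placeholder assumptions:

# Assumed paths; converts a diffusers ControlNet state dict to original SD key layout.
from safetensors.torch import save_file
from src.convert_sd import load_state_dict, convert_unet_state_dict

hf_state_dict = load_state_dict("models/controlnet/diffusion_pytorch_model.safetensors")
sd_state_dict = convert_unet_state_dict(hf_state_dict, is_controlnet=True)
save_file(sd_state_dict, "models/controlnet_sd_format.safetensors")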
src/data.py
ADDED
@@ -0,0 +1,149 @@
from pathlib import Path
from PIL import Image
import torchvision
import random

from torch.utils.data import Dataset, DataLoader
from functools import partial
from multiprocessing import cpu_count
from datasets import load_dataset

import cv2
import numpy as np
import torch


class PNGDataset(Dataset):
    def __init__(
        self,
        data_dir,
        tokenizer,
        from_hf_hub=False,
        ucg=0.10,
        resolution=(512, 512),
        prompt_key="tags",
        cond_key="cond",
        target_key="image",
        controlnet_hint_key=None,
        file_extension="png",
    ):
        super().__init__()
        vars(self).update(locals())

        if from_hf_hub:
            self.img_paths = load_dataset(data_dir)["train"]
        else:
            self.img_paths = list(Path(data_dir).glob(f"*.{file_extension}"))

        self.ucg = ucg

        self.flip_transform = torchvision.transforms.RandomHorizontalFlip(p=0.5)
        self.transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(resolution),
                torchvision.transforms.ToTensor(),
            ]
        )
        self.normalize = torchvision.transforms.Normalize([0.5], [0.5])

    def process_canny(self, image):
        # code from https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/controlnet
        image = np.array(image)
        low_threshold, high_threshold = (100, 200)
        image = cv2.Canny(image, low_threshold, high_threshold)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)

        return canny_image

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        if self.from_hf_hub:
            image = self.img_paths[idx]["image"]
        else:
            image = Image.open(self.img_paths[idx])

        if self.prompt_key not in image.info:
            print(f"Image {idx} lacks {self.prompt_key}, skipping to next image")
            return self.__getitem__((idx + 1) % len(self))

        if random.random() < self.ucg:
            tags = ""
        else:
            tags = image.info[self.prompt_key]

        # randomly flip image here so input image to canny has matching flip
        image = self.flip_transform(image)

        target = self.normalize(self.transforms(image))

        output_dict = {self.target_key: target, self.cond_key: tags}

        if self.controlnet_hint_key == "canny":
            canny_image = self.transforms(self.process_canny(image))
            output_dict[self.controlnet_hint_key] = canny_image

        return output_dict

    def collate_fn(self, samples):
        prompts = torch.tensor(
            [
                self.tokenizer(
                    sample[self.cond_key],
                    padding="max_length",
                    truncation=True,
                ).input_ids
                for sample in samples
            ]
        )

        images = torch.stack(
            [sample[self.target_key] for sample in samples]
        ).contiguous()

        batch = {
            self.cond_key: prompts,
            self.target_key: images,
        }

        if self.controlnet_hint_key is not None:
            hint = torch.stack(
                [sample[self.controlnet_hint_key] for sample in samples]
            ).contiguous()
            batch[self.controlnet_hint_key] = hint

        return batch


class PNGDataModule:
    def __init__(
        self,
        batch_size=1,
        num_workers=None,
        persistent_workers=True,
        **kwargs,  # passed to dataset class
    ):
        super().__init__()
        vars(self).update(locals())

        if num_workers is None:
            num_workers = cpu_count() // 2

        self.ds_wrapper = partial(PNGDataset, **kwargs)

        self.dl_wrapper = partial(
            DataLoader,
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
        )

    def get_dataloader(self, data_dir, shuffle=False):
        dataset = self.ds_wrapper(data_dir=data_dir)
        dataloader = self.dl_wrapper(
            dataset, shuffle=shuffle, collate_fn=dataset.collate_fn
        )
        return dataloader
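A sketch of how PNGDataModule might be wired up for canny-conditioned training; the tokenizer repo and the local data folder are assumptions, and the PNGs are expected to carry their prompt in the "tags" metadata field:

# Assumed tokenizer and data folder; yields batches with "image", "cond" and "canny" keys.
from transformers import CLIPTokenizer
from src.data import PNGDataModule

tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
datamodule = PNGDataModule(
    batch_size=2,
    num_workers=2,
    tokenizer=tokenizer,
    controlnet_hint_key="canny",
)
loader = datamodule.get_dataloader("data/train", shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape, batch["cond"].shape, batch["canny"].shape)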
src/lab.py
ADDED
@@ -0,0 +1,474 @@
# modified starting from HuggingFace diffusers train_dreambooth.py example
# https://github.com/huggingface/diffusers/blob/024c4376fb19caa85275c038f071b6e1446a5cad/examples/dreambooth/train_dreambooth.py

import os
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from PIL import Image
from tqdm.auto import tqdm

from diffusers import AutoencoderKL, StableDiffusionPipeline

from torchvision.utils import make_grid
import numpy as np

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

from diffusers.schedulers import UniPCMultistepScheduler

from .data import PNGDataModule

logger = get_logger(__name__)


class Lab(Accelerator):
    def __init__(self, args, control_pipe=None):
        self.cond_key = "prompts"
        self.target_key = "images"
        self.args = args

        self.output_dir = Path(args.output_dir)
        logging_dir = str(self.output_dir / "logs")

        accelerator_project_config = ProjectConfiguration(
            logging_dir=logging_dir,
        )

        super().__init__(
            mixed_precision=args.mixed_precision,
            log_with=args.report_to,
            project_config=accelerator_project_config,
        )

        if self.mixed_precision == "fp16":
            self.weight_dtype = torch.float16
        elif self.mixed_precision == "bf16":
            self.weight_dtype = torch.bfloat16
        else:
            self.weight_dtype = torch.float32

        if args.seed is not None:
            set_seed(args.seed)

        if control_pipe is None:
            control_pipe = self.load_pipe(
                args.pretrained_model_name_or_path, args.controlnet_weights_path
            )
        self.control_pipe = control_pipe

        vae = control_pipe.vae
        unet = control_pipe.unet
        text_encoder = control_pipe.text_encoder
        tokenizer = control_pipe.tokenizer
        controlnet = (
            control_pipe.controlnet if hasattr(control_pipe, "controlnet") else None
        )
        self.noise_scheduler = UniPCMultistepScheduler.from_config(control_pipe.scheduler.config)

        vae.requires_grad_(False)
        text_encoder.requires_grad_(False)

        if controlnet:
            unet.requires_grad_(False)

            if args.training_stage == "zero convolutions":
                controlnet.requires_grad_(False)
                controlnet.controlnet_down_blocks.requires_grad_(True)
                controlnet.controlnet_mid_block.requires_grad_(True)
                # optimize only the zero convolution weights
                params_to_optimize = list(
                    controlnet.controlnet_down_blocks.parameters()
                ) + list(controlnet.controlnet_mid_block.parameters())

            elif args.training_stage == "input hint blocks":
                controlnet.requires_grad_(False)
                controlnet.controlnet_cond_embedding.requires_grad_(True)
                params_to_optimize = list(
                    controlnet.controlnet_cond_embedding.parameters()
                )
            else:
                controlnet.requires_grad_(True)
                params_to_optimize = list(controlnet.parameters())
        else:
            unet.requires_grad_(True)
            params_to_optimize = list(unet.parameters())

        self.params_to_optimize = params_to_optimize

        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.batch_size
            * self.num_processes
        )

        if args.use_8bit_adam:
            import bitsandbytes as bnb

            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW

        self.optimizer = self.prepare(
            optimizer_class(
                params_to_optimize,
                lr=args.learning_rate,
            )
        )

        if args.enable_xformers_memory_efficient_attention:
            unet.enable_xformers_memory_efficient_attention()
            if controlnet:
                controlnet.enable_xformers_memory_efficient_attention()

        if args.gradient_checkpointing:
            unet.enable_gradient_checkpointing()
            if controlnet:
                controlnet.enable_gradient_checkpointing()

        torch.backends.cuda.matmul.allow_tf32 = True

        datamodule = PNGDataModule(
            tokenizer=tokenizer,
            from_hf_hub=args.from_hf_hub,
            resolution=[args.resolution, args.resolution],
            target_key=self.target_key,
            cond_key=self.cond_key,
            persistent_workers=True,
            num_workers=args.dataloader_num_workers,
            batch_size=args.batch_size,
            controlnet_hint_key=None if controlnet is None else args.controlnet_hint_key,
        )

        self.train_dataloader = self.prepare(
            datamodule.get_dataloader(args.train_data_dir, shuffle=True)
        )

        if args.valid_data_dir:
            self.valid_dataloader = self.prepare(
                datamodule.get_dataloader(args.valid_data_dir)
            )

        self.vae = vae.to(self.device, dtype=self.weight_dtype)
        self.text_encoder = text_encoder.to(self.device, dtype=self.weight_dtype)

        if controlnet:
            controlnet = self.prepare(controlnet)
            self.controlnet = controlnet.to(self.device, dtype=torch.float32)
            self.unet = unet.to(self.device, dtype=self.weight_dtype)
        else:
            unet = self.prepare(unet)
            self.unet = unet.to(self.device, dtype=torch.float32)
            self.controlnet = None

    def load_pipe(self, sd_model_path, controlnet_path=None):

        if self.args.vae_path:
            vae = AutoencoderKL.from_pretrained(
                self.args.vae_path, torch_dtype=self.weight_dtype
            )

        if os.path.isfile(sd_model_path):
            file_ext = sd_model_path.rsplit(".", 1)[-1]
            from_safetensors = file_ext == "safetensors"
            pipe = download_from_original_stable_diffusion_ckpt(
                sd_model_path,
                from_safetensors=from_safetensors,
                device="cpu",
                load_safety_checker=False,
            )
            pipe.safety_checker = None
            pipe.feature_extractor = None
            if self.args.vae_path:
                pipe.vae = vae
        else:
            if self.args.vae_path:
                kw_args = dict(vae=vae)
            else:
                kw_args = dict()
            pipe = StableDiffusionPipeline.from_pretrained(
                sd_model_path,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
                torch_dtype=self.weight_dtype,
                **kw_args,
            )

        if not controlnet_path:
            return pipe

        pathobj = Path(controlnet_path)
        if pathobj.is_file():
            controlnet = ControlNetModel.from_config(
                ControlNetModel.load_config("configs/controlnet_config.json")
            )
            controlnet.load_weights_from_sd_ckpt(controlnet_path)
        else:
            controlnet_path = str(Path().joinpath(*pathobj.parts[:-1]))
            subfolder = str(pathobj.parts[-1])
            controlnet = ControlNetModel.from_pretrained(
                controlnet_path,
                subfolder=subfolder,
                low_cpu_mem_usage=False,
                device_map=None,
            )

        return StableDiffusionControlNetPipeline(
            **pipe.components,
            controlnet=controlnet,
            requires_safety_checker=False,
        )

    @torch.autocast("cuda")
    def compute_loss(self, batch):
        images = batch[self.target_key].to(dtype=self.weight_dtype)
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=latents.device,
        )
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(batch[self.cond_key])[0]

        if self.controlnet:

            if self.args.controlnet_hint_key in batch:
                controlnet_hint = batch[self.args.controlnet_hint_key].to(
                    dtype=self.weight_dtype
                )
            else:
                controlnet_hint = torch.zeros(images.shape).to(images)

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_hint,
                return_dict=False,
            )
        else:
            down_block_res_samples, mid_block_res_sample = None, None

        noise_pred = self.unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        ).sample

        # Get the target for loss depending on the prediction type
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
            )

        loss = F.mse_loss(noise_pred, target, reduction="mean")

        return loss, encoder_hidden_states

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        output_latents = self.vae.decode(latents).sample
        output_latents = (output_latents / 2 + 0.5).clamp(0, 1)
        return output_latents

    @torch.no_grad()
    @torch.autocast("cuda")
    def log_images(self, batch, encoder_hidden_states, cond_scales=[0.0, 0.5, 1.0]):
        input_tensors = batch[self.target_key].to(self.weight_dtype)
        input_tensors = (input_tensors / 2 + 0.5).clamp(0, 1)

        tensors_to_log = [input_tensors.cpu()]

        [height, width] = input_tensors.shape[-2:]

        if self.controlnet:
            if self.args.controlnet_hint_key in batch:
                controlnet_hint = batch[self.args.controlnet_hint_key].to(
                    self.weight_dtype
                )
            else:
                controlnet_hint = None

            for cond_scale in cond_scales:
                latents = self.control_pipe(
                    image=controlnet_hint,
                    prompt_embeds=encoder_hidden_states,
                    controlnet_conditioning_scale=cond_scale,
                    height=height,
                    width=width,
                    output_type="latent",
                    num_inference_steps=25,
                )[0]

                tensors_to_log.append(self.decode_latents(latents).detach().cpu())

            if controlnet_hint is not None:
                tensors_to_log.append(controlnet_hint.detach().cpu())
        else:
            latents = self.control_pipe(
                prompt_embeds=encoder_hidden_states,
                height=height,
                width=width,
                output_type="latent",
                num_inference_steps=25,
            )[0]

            tensors_to_log.append(self.decode_latents(latents).detach().cpu())

        image_tensors = torch.cat(tensors_to_log)

        grid = make_grid(image_tensors, normalize=False, nrow=input_tensors.shape[0])
        grid = grid.permute(1, 2, 0).squeeze(-1) * 255
        grid = grid.numpy().astype(np.uint8)

        image_grid = Image.fromarray(grid)
        image_grid.save(Path(self.trackers[0].logging_dir) / f"{self.global_step}.png")

    def save_weights(self, to_safetensors=True):
        save_dir = self.output_dir / f"checkpoint-{self.global_step}"
        os.makedirs(save_dir, exist_ok=True)

        if self.args.save_whole_pipeline:
            self.control_pipe.save_pretrained(
                str(save_dir), safe_serialization=to_safetensors
            )
        elif self.controlnet:
            self.controlnet.save_pretrained(
                str(save_dir / "controlnet"), safe_serialization=to_safetensors
            )
        else:
            self.unet.save_pretrained(
                str(save_dir / "unet"), safe_serialization=to_safetensors
            )

    def train(self, num_train_epochs=1000):
        args = self.args

        max_train_steps = (
            num_train_epochs
            * len(self.train_dataloader)
            // args.gradient_accumulation_steps
        )

        if self.is_main_process:
            self.init_trackers("tb_logs", config=vars(args))

        self.global_step = 0

        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            range(max_train_steps),
            disable=not self.is_local_main_process,
        )
        progress_bar.set_description("Steps")

        try:
            for epoch in range(num_train_epochs):
                # run training loop
                if self.controlnet:
                    self.controlnet.train()
                else:
                    self.unet.train()
                for batch in self.train_dataloader:
                    loss, encoder_hidden_states = self.compute_loss(batch)

                    loss /= args.gradient_accumulation_steps
                    self.backward(loss)
                    if self.global_step % args.gradient_accumulation_steps == 0:
                        if self.sync_gradients:
                            self.clip_grad_norm_(
                                self.params_to_optimize, args.max_grad_norm
                            )
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if self.sync_gradients:
                        progress_bar.update(1)
                        self.global_step += 1

                        if self.is_main_process:
                            if self.global_step % args.checkpointing_steps == 0:
                                self.save_weights()

                            if args.image_logging_steps and (
                                self.global_step % args.image_logging_steps == 0
                                or self.global_step == 1
                            ):
                                self.log_images(batch, encoder_hidden_states)

                    logs = {"training_loss": loss.detach().item()}
                    self.log(logs, step=self.global_step)
                    progress_bar.set_postfix(**logs)

                    if self.global_step >= max_train_steps:
                        break

                self.wait_for_everyone()

                # run validation loop
                if args.valid_data_dir:
                    total_valid_loss = 0
                    if self.controlnet:
                        self.controlnet.eval()
                    else:
                        self.unet.eval()

                    for batch in self.valid_dataloader:
                        with torch.no_grad():
                            loss, encoder_hidden_states = self.compute_loss(batch)

                        loss = loss.detach().item()
                        total_valid_loss += loss
                        logs = {"validation_loss": loss}
                        progress_bar.set_postfix(**logs)

                    self.log(
                        {
                            "validation_loss": total_valid_loss
                            / len(self.valid_dataloader)
                        },
                        step=self.global_step,
                    )
                    self.wait_for_everyone()

        except KeyboardInterrupt:
            print("Keyboard interrupt detected, attempting to save trained weights")

        # except Exception as e:
        #     print(f"Encountered error {e}, attempting to save trained weights")

        self.save_weights()

        self.end_training()
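A minimal sketch of launching Lab outside the Gradio UI; every value below is an illustrative assumption rather than a recommended configuration:

# Placeholder hyperparameters and paths; a sketch, not canonical settings.
import argparse
from src.lab import Lab

args = argparse.Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    controlnet_weights_path="lint/anime_control/anime_merge",
    vae_path=None,
    train_data_dir="lint/anybooru",
    valid_data_dir=None,
    from_hf_hub=True,
    controlnet_hint_key=None,
    resolution=512,
    training_stage="",  # or "zero convolutions" / "input hint blocks"
    learning_rate=1e-5,
    seed=42,
    max_grad_norm=1.0,
    gradient_accumulation_steps=4,
    batch_size=1,
    mixed_precision="fp16",
    gradient_checkpointing=True,
    use_8bit_adam=False,
    enable_xformers_memory_efficient_attention=False,
    dataloader_num_workers=4,
    output_dir="models/anime_controlnet",
    report_to="tensorboard",
    image_logging_steps=200,
    save_whole_pipeline=False,
    checkpointing_steps=500,
)

lab = Lab(args)
lab.train(num_train_epochs=1)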
src/ui_assets/controlnet_ids.txt
ADDED
@@ -0,0 +1,4 @@
anime_merge
anime_dream
anime_protogen
anime_neverending
src/ui_assets/examples
ADDED
@@ -0,0 +1 @@
../../examples
src/ui_assets/footer.html
ADDED
@@ -0,0 +1,9 @@

<!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->


<div class="footer">
    <p><h4>LICENSE</h4>
The default model is licensed with a <a href="https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL" style="text-decoration: underline;" target="_blank">CreativeML OpenRAIL++</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
</div>
src/ui_assets/header.html
ADDED
@@ -0,0 +1,23 @@

<!-- based on https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/main/app.py -->

<div style="text-align: center; margin: 0 auto;">
  <div
    style="
      display: inline-flex;
      align-items: center;
      gap: 0.8rem;
      font-size: 1.75rem;
    "
  >
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 32 32" style="enable-background:new 0 0 512 512;" xml:space="preserve" width="32" height="32"><path style="fill:#FCD577;" d="M29.545 29.791V2.21c-1.22 0 -2.21 -0.99 -2.21 -2.21H4.665c0 1.22 -0.99 2.21 -2.21 2.21v27.581c1.22 0 2.21 0.99 2.21 2.21H27.335C27.335 30.779 28.325 29.791 29.545 29.791z"/><path x="98.205" y="58.928" style="fill:#99B6C6;" width="315.577" height="394.144" d="M6.138 3.683H25.861V28.317H6.138V3.683z"/><path x="98.205" y="58.928" style="fill:#7BD4EF;" width="315.577" height="131.317" d="M6.138 3.683H25.861V11.89H6.138V3.683z"/><g><path style="fill:#7190A5;" d="M14.498 10.274c0 1.446 0.983 1.155 1.953 1.502l0.504 5.317c0 0 -5.599 0.989 -6.026 2.007l0.27 -2.526c0.924 -1.462 1.286 -4.864 1.419 -6.809l0.086 0.006C12.697 9.876 14.498 10.166 14.498 10.274z"/><path style="fill:#7190A5;" d="M21.96 17.647c0 0 -0.707 1.458 -1.716 1.903c0 0 -1.502 -0.827 -1.502 -0.827c-2.276 -1.557 -2.366 -8.3 -2.366 -8.3c0 -1.718 -0.185 -1.615 -1.429 -1.615c-1.167 0 -2.127 -0.606 -2.242 0.963l-0.086 -0.006c0.059 -0.859 0.074 -1.433 0.074 -1.433c0 -1.718 1.449 -3.11 3.237 -3.11s3.237 1.392 3.237 3.11C19.168 8.332 19.334 15.617 21.96 17.647z"/></g><path style="fill:#6C8793;" d="M12.248 24.739c1.538 0.711 3.256 1.591 3.922 2.258c-1.374 0.354 -2.704 0.798 -3.513 1.32h-2.156c-1.096 -0.606 -2.011 -1.472 -2.501 -2.702c-1.953 -4.907 2.905 -8.664 2.905 -8.664c0.001 -0.001 0.002 -0.002 0.003 -0.003c0.213 -0.214 0.523 -0.301 0.811 -0.21l0.02 0.006c-0.142 0.337 -0.03 0.71 0.517 1.108c1.264 0.919 3.091 1.131 4.416 1.143c-1.755 1.338 -3.42 3.333 -4.367 5.618L12.248 24.739z"/><path style="fill:#577484;" d="M16.17 26.997c-0.666 -0.666 -2.385 -1.548 -3.922 -2.258l0.059 -0.126c0.947 -2.284 2.612 -4.28 4.367 -5.618c0.001 0 0.001 0 0.001 0c0.688 -0.525 1.391 -0.948 2.068 -1.247c0.001 0 0.001 0 0.001 0c1.009 -0.446 1.964 -0.617 2.742 -0.44c0.61 0.138 1.109 0.492 1.439 1.095c1.752 3.205 0.601 9.913 0.601 9.913H12.657C13.466 27.796 14.796 27.352 16.17 26.997z"/><path style="fill:#F7DEB0;" d="M14.38 13.1c-0.971 -0.347 -1.687 -1.564 -1.687 -3.01c0 -0.107 0.004 -0.213 0.011 -0.318c0.116 -1.569 1.075 -2.792 2.242 -2.792c1.244 0 2.253 1.392 2.253 3.11c0 0 -0.735 6.103 1.542 7.66c-0.677 0.299 -1.38 0.722 -2.068 1.247c0 0 0 0 -0.001 0c-1.326 -0.012 -3.152 -0.223 -4.416 -1.143c-0.547 -0.398 -0.659 -0.771 -0.517 -1.108c0.426 -1.018 3.171 -1.697 3.171 -1.697L14.38 13.1z"/><path style="fill:#E5CA9E;" d="M14.38 13.1c0 0 1.019 0.216 1.544 -0.309c0 0 -0.401 1.04 -1.346 1.04"/><g><path style="fill:#EAC36E;" points="437.361,0 413.79,58.926 472.717,35.356 " d="M27.335 0L25.862 3.683L29.545 2.21"/><path style="fill:#EAC36E;" points="437.361,512 413.79,453.074 472.717,476.644 " d="M27.335 32L25.862 28.317L29.545 29.791"/><path style="fill:#EAC36E;" points="74.639,512 98.21,453.074 39.283,476.644 " d="M4.665 32L6.138 28.317L2.455 29.791"/><path style="fill:#EAC36E;" points="39.283,35.356 98.21,58.926 74.639,0 " d="M2.455 2.21L6.138 3.683L4.665 0"/><path style="fill:#EAC36E;" d="M26.425 28.881H5.574V3.119h20.851v25.761H26.425zM6.702 27.754h18.597V4.246H6.702V27.754z"/></g><g><path style="fill:#486572;" d="M12.758 21.613c-0.659 0.767 -1.245 1.613 -1.722 2.531l0.486 0.202C11.82 23.401 12.241 22.483 12.758 21.613z"/><path style="fill:#486572;" d="M21.541 25.576l-0.37 0.068c-0.553 0.101 -1.097 0.212 -1.641 0.331l-0.071 -0.201l-0.059 -0.167c-0.019 -0.056 -0.035 -0.112 -0.052 -0.169l-0.104 -0.338l-0.088 
-0.342c-0.112 -0.457 -0.197 -0.922 -0.235 -1.393c-0.035 -0.47 -0.032 -0.947 0.042 -1.417c0.072 -0.47 0.205 -0.935 0.422 -1.369c-0.272 0.402 -0.469 0.856 -0.606 1.329c-0.138 0.473 -0.207 0.967 -0.234 1.462c-0.024 0.496 0.002 0.993 0.057 1.487l0.046 0.37l0.063 0.367c0.011 0.061 0.02 0.123 0.033 0.184l0.039 0.182l0.037 0.174c-0.677 0.157 -1.351 0.327 -2.019 0.514c-0.131 0.037 -0.262 0.075 -0.392 0.114l0.004 -0.004c-0.117 -0.095 -0.232 -0.197 -0.35 -0.275c-0.059 -0.041 -0.117 -0.084 -0.177 -0.122l-0.179 -0.112c-0.239 -0.147 -0.482 -0.279 -0.727 -0.406c-0.489 -0.252 -0.985 -0.479 -1.484 -0.697c-0.998 -0.433 -2.01 -0.825 -3.026 -1.196c0.973 0.475 1.937 0.969 2.876 1.499c0.469 0.266 0.932 0.539 1.379 0.832c0.223 0.146 0.442 0.297 0.648 0.456l0.154 0.119c0.05 0.041 0.097 0.083 0.145 0.124c0.002 0.002 0.004 0.003 0.005 0.005c-0.339 0.109 -0.675 0.224 -1.009 0.349c-0.349 0.132 -0.696 0.273 -1.034 0.431c-0.338 0.159 -0.668 0.337 -0.973 0.549c0.322 -0.186 0.662 -0.334 1.01 -0.463c0.347 -0.129 0.701 -0.239 1.056 -0.34c0.394 -0.111 0.79 -0.208 1.19 -0.297c0.006 0.006 0.013 0.013 0.019 0.019l0.03 -0.03c0.306 -0.068 0.614 -0.132 0.922 -0.192c0.727 -0.14 1.457 -0.258 2.189 -0.362c0.731 -0.103 1.469 -0.195 2.197 -0.265l0.374 -0.036L21.541 25.576z"/></g></svg>

    <h1 style="font-weight: 1000; margin-bottom: 8px;margin-top:8px">
      <a href="https://github.com/1lint/style_controlnet">
        Style ControlNet Web UI
      </a>
    </h1>
  </div>
  <p> Use the ControlNet architecture to control Stable Diffusion image generation style</p>
</div>
src/ui_assets/model_ids.txt
ADDED
@@ -0,0 +1,5 @@
lint/liquidfix
prompthero/openjourney-v2
Lykon/DreamShaper
darkstorm2150/Protogen_x5.8_Official_Release
runwayml/stable-diffusion-v1-5
src/ui_functions.py
ADDED
@@ -0,0 +1,285 @@
import gradio as gr
import torch
import random
from PIL import Image
import os
import argparse
import shutil
import gc
import importlib
import json
from multiprocessing import cpu_count
import cv2
import numpy as np
from pathlib import Path

from diffusers import (
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
    ControlNetModel,
    AutoencoderKL,
)

from src.controlnet_pipe import ControlNetPipe as StableDiffusionControlNetPipeline

from src.lab import Lab

from src.ui_shared import (
    default_scheduler,
    scheduler_dict,
    model_ids,
    controlnet_ids,
    is_hfspace,
)

CONTROLNET_REPO = "lint/anime_control"
_xformers_available = importlib.util.find_spec("xformers") is not None
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = None
loaded_model_id = ""
loaded_controlnet_id = ""


def load_pipe(model_id, controlnet_id, scheduler_name):
    global pipe, loaded_model_id, loaded_controlnet_id

    scheduler = scheduler_dict[scheduler_name]

    reload_pipe = False

    if pipe:
        new_weights = pipe.components
    else:
        new_weights = {}

    if model_id != loaded_model_id:

        new_pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            vae=AutoencoderKL.from_pretrained("lint/anime_vae", torch_dtype=dtype),
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
            use_safetensors=False,
            torch_dtype=dtype,
        )
        loaded_model_id = model_id
        new_weights.update(new_pipe.components)
        new_weights["scheduler"] = scheduler.from_pretrained(model_id, subfolder="scheduler")
        reload_pipe = True

    if controlnet_id != loaded_controlnet_id:

        controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_REPO,
            subfolder=controlnet_id,
            torch_dtype=dtype,
        )
        loaded_controlnet_id = controlnet_id
        new_weights["controlnet"] = controlnet
        reload_pipe = True

    if reload_pipe:
        pipe = StableDiffusionControlNetPipeline(
            **new_weights,
            requires_safety_checker=False,
        )

    if device == "cuda":
        for component in pipe.components.values():
            if isinstance(component, torch.nn.Module):
                component.to("cuda", torch.float16)
        if _xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()

    return pipe


# initialize with preloaded pipe
if is_hfspace:
    pipe = load_pipe(model_ids[0], controlnet_ids[0], default_scheduler)


def extract_canny(image):
    CANNY_THRESHOLD = (100, 200)

    image_array = np.asarray(image)
    canny_image = cv2.Canny(image_array, *CANNY_THRESHOLD)
    canny_image = canny_image[:, :, None]
    canny_image = np.concatenate([canny_image] * 3, axis=2)

    return Image.fromarray(canny_image)


@torch.no_grad()
def generate(
    model_name,
    guidance_image,
    controlnet_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    neg_prompt="",
    controlnet_prompt=None,
    controlnet_negative_prompt=None,
    controlnet_cond_scale=1.0,
    progress=gr.Progress(track_tqdm=True),
):

    if seed == -1:
        seed = random.randint(0, 2147483647)

    if guidance_image:
        guidance_image = extract_canny(guidance_image)
    else:
        guidance_image = torch.zeros(n_images, 3, height, width)

    generator = torch.Generator(device).manual_seed(seed)

    pipe = load_pipe(
        model_id=model_name,
        controlnet_id=controlnet_name,
        scheduler_name=scheduler_name,
    )

    status_message = f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance} | Scheduler: {scheduler_name} | Steps: {steps}"

    # pass None so pipeline uses base prompt as controlnet_prompt
    if controlnet_prompt == "":
        controlnet_prompt = None
    if controlnet_negative_prompt == "":
        controlnet_negative_prompt = None

    if controlnet_prompt:
        controlnet_prompt_embeds = pipe._encode_prompt(
            controlnet_prompt,
            device,
            n_images,
            do_classifier_free_guidance=guidance > 1.0,
            negative_prompt=controlnet_negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
    else:
        controlnet_prompt_embeds = None

    result = pipe(
        prompt,
        image=guidance_image,
        height=height,
        width=width,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        negative_prompt=neg_prompt,
        num_images_per_prompt=n_images,
        generator=generator,
        controlnet_conditioning_scale=float(controlnet_cond_scale),
        controlnet_prompt_embeds=controlnet_prompt_embeds,
    )

    return result.images, status_message


def run_training(
    model_name,
    controlnet_weights_path,
    train_data_dir,
    valid_data_dir,
    train_batch_size,
    train_whole_controlnet,
    gradient_accumulation_steps,
    max_train_steps,
    train_learning_rate,
    output_dir,
    checkpointing_steps,
    image_logging_steps,
    save_whole_pipeline,
    progress=gr.Progress(track_tqdm=True),
):
    global pipe

    if device == "cpu":
        raise gr.Error("Training not supported on CPU")

    pathobj = Path(controlnet_weights_path)

    controlnet_path = str(Path().joinpath(*pathobj.parts[:-1]))
    subfolder = str(pathobj.parts[-1])
    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=subfolder,
        low_cpu_mem_usage=False,
        device_map=None,
    )

    pipe.components["controlnet"] = controlnet

    pipe = StableDiffusionControlNetPipeline(
        **pipe.components,
        requires_safety_checker=False,
    )

    training_args = argparse.Namespace(
        # start training from preexisting models
        pretrained_model_name_or_path=None,
        controlnet_weights_path=None,

        # dataset args
        train_data_dir=train_data_dir,
        valid_data_dir=valid_data_dir,
        resolution=512,
        from_hf_hub=(train_data_dir == "lint/anybooru"),
        controlnet_hint_key=None,

        # training args
        # options are ["zero convolutions", "input hint blocks"], trains whole controlnet by default
        training_stage="" if train_whole_controlnet else "zero convolutions",
        learning_rate=float(train_learning_rate),
        num_train_epochs=1000,
        max_train_steps=int(max_train_steps),
        seed=3434554,
        max_grad_norm=1.0,
        gradient_accumulation_steps=int(gradient_accumulation_steps),

        # VRAM args
        batch_size=train_batch_size,
        mixed_precision="fp16",  # set to "fp16" for mixed-precision training.
        gradient_checkpointing=True,  # set this to True to lower the memory usage.
        use_8bit_adam=False,  # use 8bit optimizer from bitsandbytes
        enable_xformers_memory_efficient_attention=True,
        allow_tf32=True,
        dataloader_num_workers=cpu_count(),

        # logging args
        output_dir=output_dir,
        report_to="tensorboard",
        image_logging_steps=image_logging_steps,  # disabled when 0. costs additional VRAM to log images
        save_whole_pipeline=save_whole_pipeline,
        checkpointing_steps=checkpointing_steps,
    )

    try:
        lab = Lab(training_args, pipe)
        lab.train(training_args.num_train_epochs)
    except Exception as e:
        raise gr.Error(e)

    for component in pipe.components.values():
        if isinstance(component, torch.nn.Module):
            component.to(device, dtype=dtype)

    gc.collect()
    torch.cuda.empty_cache()

    return f"Finished training! Check the {training_args.output_dir} directory for saved model weights"
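The UI handlers above can also be driven directly from Python. A hedged sketch: the model and controlnet names come from the ui_assets lists, the prompt is a placeholder, and calling this will download the referenced weights:

# Calls the Gradio handler directly; downloads models on first use.
from src.ui_functions import generate

images, status = generate(
    model_name="lint/liquidfix",
    guidance_image=None,
    controlnet_name="anime_merge",
    scheduler_name="UniPCMultistepScheduler",
    prompt="1girl, blue hair, portrait",
    guidance=7.5,
    steps=25,
    seed=-1,
)
print(status)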
src/ui_shared.py
ADDED
@@ -0,0 +1,24 @@
import diffusers.schedulers
import os
from pathlib import Path

assets_directory = Path(__file__).parent / "ui_assets"

is_hfspace = "SPACE_REPO_NAME" in os.environ

scheduler_dict = {
    k: v
    for k, v in diffusers.schedulers.__dict__.items()
    if "Scheduler" in k and "Flax" not in k
}
scheduler_dict.pop(
    "VQDiffusionScheduler", None
)  # requires unique parameter, unlike other schedulers
scheduler_names = list(scheduler_dict.keys())
default_scheduler = "UniPCMultistepScheduler"

with open(assets_directory / "model_ids.txt", "r") as fp:
    model_ids = fp.read().splitlines()

with open(assets_directory / "controlnet_ids.txt", "r") as fp:
    controlnet_ids = fp.read().splitlines()
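For reference, a small sketch of how the scheduler registry above is consumed by the UI code (load_pipe looks schedulers up by class name); the model id below is an assumption:

# Look up a scheduler class by name and instantiate it for a given model repo.
from src.ui_shared import scheduler_dict, default_scheduler

scheduler_cls = scheduler_dict[default_scheduler]
scheduler = scheduler_cls.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
print(type(scheduler).__name__)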