Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- .gitignore +162 -0
- LICENSE +201 -0
- README.md +35 -12
- assets/demos.pdf +3 -0
- assets/demos.png +3 -0
- inference.py +59 -0
- src/pipeline.py +373 -0
- src/pipeline_img2img.py +353 -0
- src/pipeline_inpaint.py +378 -0
- src/scheduler.py +175 -0
- src/transformer.py +1215 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/demos.pdf filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/demos.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
110 |
+
.pdm.toml
|
111 |
+
.pdm-python
|
112 |
+
.pdm-build/
|
113 |
+
|
114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
115 |
+
__pypackages__/
|
116 |
+
|
117 |
+
# Celery stuff
|
118 |
+
celerybeat-schedule
|
119 |
+
celerybeat.pid
|
120 |
+
|
121 |
+
# SageMath parsed files
|
122 |
+
*.sage.py
|
123 |
+
|
124 |
+
# Environments
|
125 |
+
.env
|
126 |
+
.venv
|
127 |
+
env/
|
128 |
+
venv/
|
129 |
+
ENV/
|
130 |
+
env.bak/
|
131 |
+
venv.bak/
|
132 |
+
|
133 |
+
# Spyder project settings
|
134 |
+
.spyderproject
|
135 |
+
.spyproject
|
136 |
+
|
137 |
+
# Rope project settings
|
138 |
+
.ropeproject
|
139 |
+
|
140 |
+
# mkdocs documentation
|
141 |
+
/site
|
142 |
+
|
143 |
+
# mypy
|
144 |
+
.mypy_cache/
|
145 |
+
.dmypy.json
|
146 |
+
dmypy.json
|
147 |
+
|
148 |
+
# Pyre type checker
|
149 |
+
.pyre/
|
150 |
+
|
151 |
+
# pytype static type analyzer
|
152 |
+
.pytype/
|
153 |
+
|
154 |
+
# Cython debug symbols
|
155 |
+
cython_debug/
|
156 |
+
|
157 |
+
# PyCharm
|
158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
+
#.idea/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,35 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Meissonic: Revitalizing Masked Generative Transformers for Efficient High-Resolution Text-to-Image Synthesis
|
2 |
+
|
3 |
+
[Paper](https://arxiv.org/abs/2410.08261) | [Model](https://huggingface.co/MeissonFlow/Meissonic) | [Code](https://github.com/viiika/Meissonic)
|
4 |
+
|
5 |
+
|
6 |
+
![demo](./assets/demos.png)
|
7 |
+
|
8 |
+
## Introduction
|
9 |
+
Meissonic is a non-autoregressive mask image modeling text-to-image synthesis model that can generate high-resolution images. It is designed to run on consumer graphics cards.
|
10 |
+
|
11 |
+
## Prerequisites
|
12 |
+
|
13 |
+
```bash
|
14 |
+
git clone https://github.com/huggingface/diffusers.git
|
15 |
+
cd diffusers
|
16 |
+
pip install -e .
|
17 |
+
```
|
18 |
+
|
19 |
+
## Usage
|
20 |
+
|
21 |
+
```bash
|
22 |
+
python inference.py
|
23 |
+
```
|
24 |
+
|
25 |
+
|
26 |
+
## Citation
|
27 |
+
If you find this work helpful, please consider citing:
|
28 |
+
```bibtex
|
29 |
+
@article{bai2024meissonic,
|
30 |
+
title={Meissonic: Revitalizing Masked Generative Transformers for Efficient High-Resolution Text-to-Image Synthesis},
|
31 |
+
author={Bai, Jinbin and Ye, Tian and Chow, Wei and Song, Enxin and Chen, Qing-Guo and Li, Xiangtai and Dong, Zhen and Zhu, Lei and Yan, Shuicheng},
|
32 |
+
journal={arXiv preprint arXiv:2410.08261},
|
33 |
+
year={2024}
|
34 |
+
}
|
35 |
+
```
|
assets/demos.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1d14191e0b8e9fdf4cb3a7199cf36554e60e456cdeba11509d305a8201e6b131
|
3 |
+
size 2476203
|
assets/demos.png
ADDED
Git LFS Details
|
inference.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append("./")
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torchvision import transforms
|
8 |
+
from src.transformer import Transformer2DModel
|
9 |
+
from src.pipeline import Pipeline
|
10 |
+
from src.scheduler import Scheduler
|
11 |
+
from transformers import (
|
12 |
+
CLIPTextModelWithProjection,
|
13 |
+
CLIPTokenizer,
|
14 |
+
)
|
15 |
+
from diffusers import VQModel
|
16 |
+
|
17 |
+
device = 'cuda'
|
18 |
+
|
19 |
+
model_path = "MeissonFlow/Meissonic"
|
20 |
+
model = Transformer2DModel.from_pretrained(model_path,subfolder="transformer",)
|
21 |
+
vq_model = VQModel.from_pretrained(model_path, subfolder="vqvae", )
|
22 |
+
text_encoder = CLIPTextModelWithProjection.from_pretrained(model_path,subfolder="text_encoder",)
|
23 |
+
tokenizer = CLIPTokenizer.from_pretrained(model_path,subfolder="tokenizer",)
|
24 |
+
scheduler = Scheduler.from_pretrained(model_path,subfolder="scheduler",)
|
25 |
+
pipe=Pipeline(vq_model, tokenizer=tokenizer,text_encoder=text_encoder,transformer=model,scheduler=scheduler)
|
26 |
+
|
27 |
+
pipe = pipe.to(device)
|
28 |
+
|
29 |
+
steps = 48
|
30 |
+
CFG = 9
|
31 |
+
resolution = 1024
|
32 |
+
negative_prompts = "worst quality, normal quality, low quality, low res, blurry, distortion, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch, duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, bad anatomy, bad proportions, bad quality, deformed, disconnected limbs, out of frame, out of focus, dehydrated, disfigured, extra arms, extra limbs, extra hands, fused fingers, gross proportions, long neck, jpeg, malformed limbs, mutated, mutated hands, mutated limbs, missing arms, missing fingers, picture frame, poorly drawn hands, poorly drawn face, collage, pixel, pixelated, grainy, color aberration, amputee, autograph, bad illustration, beyond the borders, blank background, body out of frame, boring background, branding, cut off, dismembered, disproportioned, distorted, draft, duplicated features, extra fingers, extra legs, fault, flaw, grains, hazy, identifying mark, improper scale, incorrect physiology, incorrect ratio, indistinct, kitsch, low resolution"
|
33 |
+
|
34 |
+
|
35 |
+
# A racoon wearing a suit smoking a cigar in the style of James Gurney.
|
36 |
+
# Medieval painting of a rat king.
|
37 |
+
# Oil portrait of Super Mario as a shaman tripping on mushrooms in a dark and detailed scene.
|
38 |
+
# A painting of a Persian cat dressed as a Renaissance king, standing on a skyscraper overlooking a city.
|
39 |
+
# A fluffy owl sits atop a stack of antique books in a detailed and moody illustration.
|
40 |
+
# A cosmonaut otter poses for a portrait painted in intricate detail by Rembrandt.
|
41 |
+
# A painting featuring a woman wearing virtual reality glasses and a bird, created by Dave McKean and Ivan Shishkin.
|
42 |
+
# A hyperrealist portrait of a fairy girl emperor wearing a crown and long starry robes.
|
43 |
+
# A psychedelic painting of a fantasy space whale.
|
44 |
+
# A monkey in a blue top hat painted in oil by Vincent van Gogh in the 1800s.
|
45 |
+
# A queen with red hair and a green and black dress stands veiled in a highly detailed and elegant digital painting.
|
46 |
+
# An oil painting of an anthropomorphic fox overlooking a village in the moor.
|
47 |
+
# A digital painting of an evil geisha in a bar.
|
48 |
+
# Digital painting of a furry deer character on FurAffinity.
|
49 |
+
# A highly detailed goddess portrait with a focus on the eyes.
|
50 |
+
# A cute young demon princess in a forest, depicted in digital painting.
|
51 |
+
# A red-haired queen wearing a green and black dress and veil is depicted in an intricate and elegant digital painting.
|
52 |
+
prompt = "A racoon wearing a suit smoking a cigar in the style of James Gurney."
|
53 |
+
|
54 |
+
image = pipe(prompt=prompt,negative_prompt=negative_prompts,height=resolution,width=resolution,guidance_scale=CFG,num_inference_steps=steps).images[0]
|
55 |
+
|
56 |
+
output_dir = "./output"
|
57 |
+
os.makedirs(output_dir, exist_ok=True)
|
58 |
+
image.save(output_dir, f"{prompt[:10]}_{resolution}_{steps}_{CFG}.png")
|
59 |
+
|
src/pipeline.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import sys
|
15 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
19 |
+
|
20 |
+
from diffusers.image_processor import VaeImageProcessor
|
21 |
+
from diffusers.models import VQModel
|
22 |
+
|
23 |
+
from src.scheduler import Scheduler
|
24 |
+
from diffusers.utils import replace_example_docstring
|
25 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
26 |
+
|
27 |
+
from src.transformer import Transformer2DModel
|
28 |
+
|
29 |
+
|
30 |
+
EXAMPLE_DOC_STRING = """
|
31 |
+
Examples:
|
32 |
+
```py
|
33 |
+
>>> image = pipe(prompt).images[0]
|
34 |
+
```
|
35 |
+
"""
|
36 |
+
|
37 |
+
|
38 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
39 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
40 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
41 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
42 |
+
|
43 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
44 |
+
|
45 |
+
latent_image_ids = latent_image_ids.reshape(
|
46 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
47 |
+
)
|
48 |
+
|
49 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
50 |
+
|
51 |
+
|
52 |
+
class Pipeline(DiffusionPipeline):
|
53 |
+
image_processor: VaeImageProcessor
|
54 |
+
vqvae: VQModel
|
55 |
+
tokenizer: CLIPTokenizer
|
56 |
+
text_encoder: CLIPTextModelWithProjection
|
57 |
+
transformer: Transformer2DModel
|
58 |
+
scheduler: Scheduler
|
59 |
+
# tokenizer_t5: T5Tokenizer
|
60 |
+
# text_encoder_t5: T5ForConditionalGeneration
|
61 |
+
|
62 |
+
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
vqvae: VQModel,
|
67 |
+
tokenizer: CLIPTokenizer,
|
68 |
+
text_encoder: CLIPTextModelWithProjection,
|
69 |
+
transformer: Transformer2DModel,
|
70 |
+
scheduler: Scheduler,
|
71 |
+
# tokenizer_t5: T5Tokenizer,
|
72 |
+
# text_encoder_t5: T5ForConditionalGeneration,
|
73 |
+
):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.register_modules(
|
77 |
+
vqvae=vqvae,
|
78 |
+
tokenizer=tokenizer,
|
79 |
+
text_encoder=text_encoder,
|
80 |
+
transformer=transformer,
|
81 |
+
scheduler=scheduler,
|
82 |
+
# tokenizer_t5=tokenizer_t5,
|
83 |
+
# text_encoder_t5=text_encoder_t5,
|
84 |
+
)
|
85 |
+
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
86 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
87 |
+
|
88 |
+
@torch.no_grad()
|
89 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
90 |
+
def __call__(
|
91 |
+
self,
|
92 |
+
prompt: Optional[Union[List[str], str]] = None,
|
93 |
+
height: Optional[int] = 1024,
|
94 |
+
width: Optional[int] = 1024,
|
95 |
+
num_inference_steps: int = 48,
|
96 |
+
guidance_scale: float = 9.0,
|
97 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
98 |
+
num_images_per_prompt: Optional[int] = 1,
|
99 |
+
generator: Optional[torch.Generator] = None,
|
100 |
+
latents: Optional[torch.IntTensor] = None,
|
101 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
102 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
103 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
104 |
+
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
|
105 |
+
output_type="pil",
|
106 |
+
return_dict: bool = True,
|
107 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
108 |
+
callback_steps: int = 1,
|
109 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
110 |
+
micro_conditioning_aesthetic_score: int = 6,
|
111 |
+
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
|
112 |
+
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
113 |
+
):
|
114 |
+
"""
|
115 |
+
The call function to the pipeline for generation.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
prompt (`str` or `List[str]`, *optional*):
|
119 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
120 |
+
height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
|
121 |
+
The height in pixels of the generated image.
|
122 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
123 |
+
The width in pixels of the generated image.
|
124 |
+
num_inference_steps (`int`, *optional*, defaults to 16):
|
125 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
126 |
+
expense of slower inference.
|
127 |
+
guidance_scale (`float`, *optional*, defaults to 10.0):
|
128 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
129 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
130 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
131 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
132 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
133 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
134 |
+
The number of images to generate per prompt.
|
135 |
+
generator (`torch.Generator`, *optional*):
|
136 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
137 |
+
generation deterministic.
|
138 |
+
latents (`torch.IntTensor`, *optional*):
|
139 |
+
Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
|
140 |
+
gneration. If not provided, the starting latents will be completely masked.
|
141 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
142 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
143 |
+
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
|
144 |
+
pooled and projected final hidden states.
|
145 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
146 |
+
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
|
147 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
148 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
149 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
150 |
+
negative_encoder_hidden_states (`torch.Tensor`, *optional*):
|
151 |
+
Analogous to `encoder_hidden_states` for the positive prompt.
|
152 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
153 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
154 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
155 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
156 |
+
plain tuple.
|
157 |
+
callback (`Callable`, *optional*):
|
158 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
159 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
160 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
161 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
162 |
+
every step.
|
163 |
+
cross_attention_kwargs (`dict`, *optional*):
|
164 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
165 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
166 |
+
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
|
167 |
+
The targeted aesthetic score according to the laion aesthetic classifier. See
|
168 |
+
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
|
169 |
+
https://arxiv.org/abs/2307.01952.
|
170 |
+
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
171 |
+
The targeted height, width crop coordinates. See the micro-conditioning section of
|
172 |
+
https://arxiv.org/abs/2307.01952.
|
173 |
+
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
|
174 |
+
Configures the temperature scheduler on `self.scheduler` see `Scheduler#set_timesteps`.
|
175 |
+
|
176 |
+
Examples:
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
|
180 |
+
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
|
181 |
+
`tuple` is returned where the first element is a list with the generated images.
|
182 |
+
"""
|
183 |
+
if (prompt_embeds is not None and encoder_hidden_states is None) or (
|
184 |
+
prompt_embeds is None and encoder_hidden_states is not None
|
185 |
+
):
|
186 |
+
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
|
187 |
+
|
188 |
+
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
|
189 |
+
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
|
190 |
+
):
|
191 |
+
raise ValueError(
|
192 |
+
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
|
193 |
+
)
|
194 |
+
|
195 |
+
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
|
196 |
+
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
|
197 |
+
|
198 |
+
if isinstance(prompt, str):
|
199 |
+
prompt = [prompt]
|
200 |
+
|
201 |
+
if prompt is not None:
|
202 |
+
batch_size = len(prompt)
|
203 |
+
else:
|
204 |
+
batch_size = prompt_embeds.shape[0]
|
205 |
+
|
206 |
+
batch_size = batch_size * num_images_per_prompt
|
207 |
+
|
208 |
+
if height is None:
|
209 |
+
height = self.transformer.config.sample_size * self.vae_scale_factor
|
210 |
+
|
211 |
+
if width is None:
|
212 |
+
width = self.transformer.config.sample_size * self.vae_scale_factor
|
213 |
+
|
214 |
+
if prompt_embeds is None:
|
215 |
+
input_ids = self.tokenizer(
|
216 |
+
prompt,
|
217 |
+
return_tensors="pt",
|
218 |
+
padding="max_length",
|
219 |
+
truncation=True,
|
220 |
+
max_length=77, #self.tokenizer.model_max_length,
|
221 |
+
).input_ids.to(self._execution_device)
|
222 |
+
# input_ids_t5 = self.tokenizer_t5(
|
223 |
+
# prompt,
|
224 |
+
# return_tensors="pt",
|
225 |
+
# padding="max_length",
|
226 |
+
# truncation=True,
|
227 |
+
# max_length=512,
|
228 |
+
# ).input_ids.to(self._execution_device)
|
229 |
+
|
230 |
+
|
231 |
+
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
232 |
+
# outputs_t5 = self.text_encoder_t5(input_ids_t5, decoder_input_ids = input_ids_t5 ,return_dict=True, output_hidden_states=True)
|
233 |
+
prompt_embeds = outputs.text_embeds
|
234 |
+
encoder_hidden_states = outputs.hidden_states[-2]
|
235 |
+
# encoder_hidden_states = outputs_t5.encoder_hidden_states[-2]
|
236 |
+
|
237 |
+
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
|
238 |
+
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
239 |
+
|
240 |
+
if guidance_scale > 1.0:
|
241 |
+
if negative_prompt_embeds is None:
|
242 |
+
if negative_prompt is None:
|
243 |
+
negative_prompt = [""] * len(prompt)
|
244 |
+
|
245 |
+
if isinstance(negative_prompt, str):
|
246 |
+
negative_prompt = [negative_prompt]
|
247 |
+
|
248 |
+
input_ids = self.tokenizer(
|
249 |
+
negative_prompt,
|
250 |
+
return_tensors="pt",
|
251 |
+
padding="max_length",
|
252 |
+
truncation=True,
|
253 |
+
max_length=77, #self.tokenizer.model_max_length,
|
254 |
+
).input_ids.to(self._execution_device)
|
255 |
+
# input_ids_t5 = self.tokenizer_t5(
|
256 |
+
# prompt,
|
257 |
+
# return_tensors="pt",
|
258 |
+
# padding="max_length",
|
259 |
+
# truncation=True,
|
260 |
+
# max_length=512,
|
261 |
+
# ).input_ids.to(self._execution_device)
|
262 |
+
|
263 |
+
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
264 |
+
# outputs_t5 = self.text_encoder_t5(input_ids_t5, decoder_input_ids = input_ids_t5 ,return_dict=True, output_hidden_states=True)
|
265 |
+
negative_prompt_embeds = outputs.text_embeds
|
266 |
+
negative_encoder_hidden_states = outputs.hidden_states[-2]
|
267 |
+
# negative_encoder_hidden_states = outputs_t5.encoder_hidden_states[-2]
|
268 |
+
|
269 |
+
|
270 |
+
|
271 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
|
272 |
+
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
273 |
+
|
274 |
+
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
|
275 |
+
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
|
276 |
+
|
277 |
+
# Note that the micro conditionings _do_ flip the order of width, height for the original size
|
278 |
+
# and the crop coordinates. This is how it was done in the original code base
|
279 |
+
micro_conds = torch.tensor(
|
280 |
+
[
|
281 |
+
width,
|
282 |
+
height,
|
283 |
+
micro_conditioning_crop_coord[0],
|
284 |
+
micro_conditioning_crop_coord[1],
|
285 |
+
micro_conditioning_aesthetic_score,
|
286 |
+
],
|
287 |
+
device=self._execution_device,
|
288 |
+
dtype=encoder_hidden_states.dtype,
|
289 |
+
)
|
290 |
+
micro_conds = micro_conds.unsqueeze(0)
|
291 |
+
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
|
292 |
+
|
293 |
+
shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
294 |
+
|
295 |
+
if latents is None:
|
296 |
+
latents = torch.full(
|
297 |
+
shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
|
298 |
+
)
|
299 |
+
|
300 |
+
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
|
301 |
+
|
302 |
+
num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order
|
303 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
304 |
+
for i, timestep in enumerate(self.scheduler.timesteps):
|
305 |
+
if guidance_scale > 1.0:
|
306 |
+
model_input = torch.cat([latents] * 2)
|
307 |
+
else:
|
308 |
+
model_input = latents
|
309 |
+
if height == 1024: #args.resolution == 1024:
|
310 |
+
img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[-2],model_input.shape[-1],model_input.device,model_input.dtype)
|
311 |
+
else:
|
312 |
+
img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[-2],2*model_input.shape[-1],model_input.device,model_input.dtype)
|
313 |
+
txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
|
314 |
+
model_output = self.transformer(
|
315 |
+
hidden_states = model_input,
|
316 |
+
micro_conds=micro_conds,
|
317 |
+
pooled_projections=prompt_embeds,
|
318 |
+
encoder_hidden_states=encoder_hidden_states,
|
319 |
+
img_ids = img_ids,
|
320 |
+
txt_ids = txt_ids,
|
321 |
+
timestep = torch.tensor([timestep], device=model_input.device, dtype=torch.long),
|
322 |
+
# guidance = 7,
|
323 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
324 |
+
)
|
325 |
+
|
326 |
+
if guidance_scale > 1.0:
|
327 |
+
uncond_logits, cond_logits = model_output.chunk(2)
|
328 |
+
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
329 |
+
|
330 |
+
latents = self.scheduler.step(
|
331 |
+
model_output=model_output,
|
332 |
+
timestep=timestep,
|
333 |
+
sample=latents,
|
334 |
+
generator=generator,
|
335 |
+
).prev_sample
|
336 |
+
|
337 |
+
if i == len(self.scheduler.timesteps) - 1 or (
|
338 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
339 |
+
):
|
340 |
+
progress_bar.update()
|
341 |
+
if callback is not None and i % callback_steps == 0:
|
342 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
343 |
+
callback(step_idx, timestep, latents)
|
344 |
+
|
345 |
+
if output_type == "latent":
|
346 |
+
output = latents
|
347 |
+
else:
|
348 |
+
needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
|
349 |
+
|
350 |
+
if needs_upcasting:
|
351 |
+
self.vqvae.float()
|
352 |
+
|
353 |
+
output = self.vqvae.decode(
|
354 |
+
latents,
|
355 |
+
force_not_quantize=True,
|
356 |
+
shape=(
|
357 |
+
batch_size,
|
358 |
+
height // self.vae_scale_factor,
|
359 |
+
width // self.vae_scale_factor,
|
360 |
+
self.vqvae.config.latent_channels,
|
361 |
+
),
|
362 |
+
).sample.clip(0, 1)
|
363 |
+
output = self.image_processor.postprocess(output, output_type)
|
364 |
+
|
365 |
+
if needs_upcasting:
|
366 |
+
self.vqvae.half()
|
367 |
+
|
368 |
+
self.maybe_free_model_hooks()
|
369 |
+
|
370 |
+
if not return_dict:
|
371 |
+
return (output,)
|
372 |
+
|
373 |
+
return ImagePipelineOutput(output)
|
src/pipeline_img2img.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
18 |
+
|
19 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
20 |
+
from diffusers.models import UVit2DModel, VQModel
|
21 |
+
# from diffusers.schedulers import AmusedScheduler
|
22 |
+
from training.scheduling import Scheduler
|
23 |
+
from diffusers.utils import replace_example_docstring
|
24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
25 |
+
|
26 |
+
from training.transformer import Transformer2DModel
|
27 |
+
|
28 |
+
EXAMPLE_DOC_STRING = """
|
29 |
+
Examples:
|
30 |
+
```py
|
31 |
+
>>> image = pipe(prompt, input_image).images[0]
|
32 |
+
```
|
33 |
+
"""
|
34 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
35 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
36 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
37 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
38 |
+
|
39 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
40 |
+
|
41 |
+
latent_image_ids = latent_image_ids.reshape(
|
42 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
43 |
+
)
|
44 |
+
# latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1)
|
45 |
+
|
46 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
47 |
+
|
48 |
+
|
49 |
+
class Img2ImgPipeline(DiffusionPipeline):
|
50 |
+
image_processor: VaeImageProcessor
|
51 |
+
vqvae: VQModel
|
52 |
+
tokenizer: CLIPTokenizer
|
53 |
+
text_encoder: CLIPTextModelWithProjection
|
54 |
+
transformer: Transformer2DModel #UVit2DModel
|
55 |
+
scheduler: Scheduler
|
56 |
+
|
57 |
+
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
|
58 |
+
|
59 |
+
# TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
|
60 |
+
# the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
|
61 |
+
# off the meta device. There should be a way to fix this instead of just not offloading it
|
62 |
+
_exclude_from_cpu_offload = ["vqvae"]
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
vqvae: VQModel,
|
67 |
+
tokenizer: CLIPTokenizer,
|
68 |
+
text_encoder: CLIPTextModelWithProjection,
|
69 |
+
transformer: Transformer2DModel, #UVit2DModel,
|
70 |
+
scheduler: Scheduler,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
self.register_modules(
|
75 |
+
vqvae=vqvae,
|
76 |
+
tokenizer=tokenizer,
|
77 |
+
text_encoder=text_encoder,
|
78 |
+
transformer=transformer,
|
79 |
+
scheduler=scheduler,
|
80 |
+
)
|
81 |
+
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
82 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
83 |
+
|
84 |
+
@torch.no_grad()
|
85 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
86 |
+
def __call__(
|
87 |
+
self,
|
88 |
+
prompt: Optional[Union[List[str], str]] = None,
|
89 |
+
image: PipelineImageInput = None,
|
90 |
+
strength: float = 0.5,
|
91 |
+
num_inference_steps: int = 12,
|
92 |
+
guidance_scale: float = 10.0,
|
93 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
94 |
+
num_images_per_prompt: Optional[int] = 1,
|
95 |
+
generator: Optional[torch.Generator] = None,
|
96 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
97 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
98 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
99 |
+
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
|
100 |
+
output_type="pil",
|
101 |
+
return_dict: bool = True,
|
102 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
103 |
+
callback_steps: int = 1,
|
104 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
105 |
+
micro_conditioning_aesthetic_score: int = 6,
|
106 |
+
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
|
107 |
+
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
108 |
+
):
|
109 |
+
"""
|
110 |
+
The call function to the pipeline for generation.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
prompt (`str` or `List[str]`, *optional*):
|
114 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
115 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
116 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
117 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
118 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
119 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
120 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
121 |
+
strength (`float`, *optional*, defaults to 0.5):
|
122 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
123 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
124 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
125 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
126 |
+
essentially ignores `image`.
|
127 |
+
num_inference_steps (`int`, *optional*, defaults to 12):
|
128 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
129 |
+
expense of slower inference.
|
130 |
+
guidance_scale (`float`, *optional*, defaults to 10.0):
|
131 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
132 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
133 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
134 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
135 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
136 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
137 |
+
The number of images to generate per prompt.
|
138 |
+
generator (`torch.Generator`, *optional*):
|
139 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
140 |
+
generation deterministic.
|
141 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
142 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
143 |
+
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
|
144 |
+
pooled and projected final hidden states.
|
145 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
146 |
+
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
|
147 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
148 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
149 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
150 |
+
negative_encoder_hidden_states (`torch.Tensor`, *optional*):
|
151 |
+
Analogous to `encoder_hidden_states` for the positive prompt.
|
152 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
153 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
154 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
155 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
156 |
+
plain tuple.
|
157 |
+
callback (`Callable`, *optional*):
|
158 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
159 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
160 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
161 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
162 |
+
every step.
|
163 |
+
cross_attention_kwargs (`dict`, *optional*):
|
164 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
165 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
166 |
+
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
|
167 |
+
The targeted aesthetic score according to the laion aesthetic classifier. See
|
168 |
+
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
|
169 |
+
https://arxiv.org/abs/2307.01952.
|
170 |
+
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
171 |
+
The targeted height, width crop coordinates. See the micro-conditioning section of
|
172 |
+
https://arxiv.org/abs/2307.01952.
|
173 |
+
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
|
174 |
+
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
|
175 |
+
|
176 |
+
Examples:
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
|
180 |
+
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
|
181 |
+
`tuple` is returned where the first element is a list with the generated images.
|
182 |
+
"""
|
183 |
+
|
184 |
+
if (prompt_embeds is not None and encoder_hidden_states is None) or (
|
185 |
+
prompt_embeds is None and encoder_hidden_states is not None
|
186 |
+
):
|
187 |
+
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
|
188 |
+
|
189 |
+
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
|
190 |
+
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
|
191 |
+
):
|
192 |
+
raise ValueError(
|
193 |
+
"pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
|
194 |
+
)
|
195 |
+
|
196 |
+
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
|
197 |
+
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
|
198 |
+
|
199 |
+
if isinstance(prompt, str):
|
200 |
+
prompt = [prompt]
|
201 |
+
|
202 |
+
if prompt is not None:
|
203 |
+
batch_size = len(prompt)
|
204 |
+
else:
|
205 |
+
batch_size = prompt_embeds.shape[0]
|
206 |
+
|
207 |
+
batch_size = batch_size * num_images_per_prompt
|
208 |
+
|
209 |
+
if prompt_embeds is None:
|
210 |
+
input_ids = self.tokenizer(
|
211 |
+
prompt,
|
212 |
+
return_tensors="pt",
|
213 |
+
padding="max_length",
|
214 |
+
truncation=True,
|
215 |
+
max_length=77, #self.tokenizer.model_max_length,
|
216 |
+
).input_ids.to(self._execution_device)
|
217 |
+
|
218 |
+
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
219 |
+
prompt_embeds = outputs.text_embeds
|
220 |
+
encoder_hidden_states = outputs.hidden_states[-2]
|
221 |
+
|
222 |
+
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
|
223 |
+
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
224 |
+
|
225 |
+
if guidance_scale > 1.0:
|
226 |
+
if negative_prompt_embeds is None:
|
227 |
+
if negative_prompt is None:
|
228 |
+
negative_prompt = [""] * len(prompt)
|
229 |
+
|
230 |
+
if isinstance(negative_prompt, str):
|
231 |
+
negative_prompt = [negative_prompt]
|
232 |
+
|
233 |
+
input_ids = self.tokenizer(
|
234 |
+
negative_prompt,
|
235 |
+
return_tensors="pt",
|
236 |
+
padding="max_length",
|
237 |
+
truncation=True,
|
238 |
+
max_length=77, #self.tokenizer.model_max_length,
|
239 |
+
).input_ids.to(self._execution_device)
|
240 |
+
|
241 |
+
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
242 |
+
negative_prompt_embeds = outputs.text_embeds
|
243 |
+
negative_encoder_hidden_states = outputs.hidden_states[-2]
|
244 |
+
|
245 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
|
246 |
+
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
247 |
+
|
248 |
+
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
|
249 |
+
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
|
250 |
+
|
251 |
+
image = self.image_processor.preprocess(image)
|
252 |
+
|
253 |
+
height, width = image.shape[-2:]
|
254 |
+
|
255 |
+
# Note that the micro conditionings _do_ flip the order of width, height for the original size
|
256 |
+
# and the crop coordinates. This is how it was done in the original code base
|
257 |
+
micro_conds = torch.tensor(
|
258 |
+
[
|
259 |
+
width,
|
260 |
+
height,
|
261 |
+
micro_conditioning_crop_coord[0],
|
262 |
+
micro_conditioning_crop_coord[1],
|
263 |
+
micro_conditioning_aesthetic_score,
|
264 |
+
],
|
265 |
+
device=self._execution_device,
|
266 |
+
dtype=encoder_hidden_states.dtype,
|
267 |
+
)
|
268 |
+
|
269 |
+
micro_conds = micro_conds.unsqueeze(0)
|
270 |
+
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
|
271 |
+
|
272 |
+
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
|
273 |
+
num_inference_steps = int(len(self.scheduler.timesteps) * strength)
|
274 |
+
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
|
275 |
+
|
276 |
+
needs_upcasting = False # = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
|
277 |
+
|
278 |
+
if needs_upcasting:
|
279 |
+
self.vqvae.float()
|
280 |
+
|
281 |
+
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
|
282 |
+
latents_bsz, channels, latents_height, latents_width = latents.shape
|
283 |
+
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
|
284 |
+
latents = self.scheduler.add_noise(
|
285 |
+
latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator
|
286 |
+
)
|
287 |
+
latents = latents.repeat(num_images_per_prompt, 1, 1)
|
288 |
+
|
289 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
290 |
+
for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
|
291 |
+
timestep = self.scheduler.timesteps[i]
|
292 |
+
|
293 |
+
if guidance_scale > 1.0:
|
294 |
+
model_input = torch.cat([latents] * 2)
|
295 |
+
else:
|
296 |
+
model_input = latents
|
297 |
+
if height == 1024: #args.resolution == 1024:
|
298 |
+
img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[-2],model_input.shape[-1],model_input.device,model_input.dtype)
|
299 |
+
else:
|
300 |
+
img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[-2],2*model_input.shape[-1],model_input.device,model_input.dtype)
|
301 |
+
txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
|
302 |
+
model_output = self.transformer(
|
303 |
+
model_input,
|
304 |
+
micro_conds=micro_conds,
|
305 |
+
pooled_projections=prompt_embeds,
|
306 |
+
encoder_hidden_states=encoder_hidden_states,
|
307 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
308 |
+
img_ids = img_ids,
|
309 |
+
txt_ids = txt_ids,
|
310 |
+
timestep = torch.tensor([timestep], device=model_input.device, dtype=torch.long),
|
311 |
+
)
|
312 |
+
|
313 |
+
if guidance_scale > 1.0:
|
314 |
+
uncond_logits, cond_logits = model_output.chunk(2)
|
315 |
+
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
316 |
+
|
317 |
+
latents = self.scheduler.step(
|
318 |
+
model_output=model_output,
|
319 |
+
timestep=timestep,
|
320 |
+
sample=latents,
|
321 |
+
generator=generator,
|
322 |
+
).prev_sample
|
323 |
+
|
324 |
+
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
|
325 |
+
progress_bar.update()
|
326 |
+
if callback is not None and i % callback_steps == 0:
|
327 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
328 |
+
callback(step_idx, timestep, latents)
|
329 |
+
|
330 |
+
if output_type == "latent":
|
331 |
+
output = latents
|
332 |
+
else:
|
333 |
+
output = self.vqvae.decode(
|
334 |
+
latents,
|
335 |
+
force_not_quantize=True,
|
336 |
+
shape=(
|
337 |
+
batch_size,
|
338 |
+
height // self.vae_scale_factor,
|
339 |
+
width // self.vae_scale_factor,
|
340 |
+
self.vqvae.config.latent_channels,
|
341 |
+
),
|
342 |
+
).sample.clip(0, 1)
|
343 |
+
output = self.image_processor.postprocess(output, output_type)
|
344 |
+
|
345 |
+
if needs_upcasting:
|
346 |
+
self.vqvae.half()
|
347 |
+
|
348 |
+
self.maybe_free_model_hooks()
|
349 |
+
|
350 |
+
if not return_dict:
|
351 |
+
return (output,)
|
352 |
+
|
353 |
+
return ImagePipelineOutput(output)
|
src/pipeline_inpaint.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
18 |
+
|
19 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
20 |
+
from diffusers.models import UVit2DModel, VQModel
|
21 |
+
# from diffusers.schedulers import AmusedScheduler
|
22 |
+
from training.scheduling import Scheduler
|
23 |
+
from diffusers.utils import replace_example_docstring
|
24 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
25 |
+
|
26 |
+
from training.transformer import Transformer2DModel
|
27 |
+
|
28 |
+
EXAMPLE_DOC_STRING = """
|
29 |
+
Examples:
|
30 |
+
```py
|
31 |
+
>>> pipe(prompt, input_image, mask).images[0].save("out.png")
|
32 |
+
```
|
33 |
+
"""
|
34 |
+
|
35 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
36 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
37 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
38 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
39 |
+
|
40 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
41 |
+
|
42 |
+
latent_image_ids = latent_image_ids.reshape(
|
43 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
44 |
+
)
|
45 |
+
# latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1)
|
46 |
+
|
47 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
48 |
+
|
49 |
+
|
50 |
+
class InpaintPipeline(DiffusionPipeline):
|
51 |
+
image_processor: VaeImageProcessor
|
52 |
+
vqvae: VQModel
|
53 |
+
tokenizer: CLIPTokenizer
|
54 |
+
text_encoder: CLIPTextModelWithProjection
|
55 |
+
transformer: Transformer2DModel #UVit2DModel
|
56 |
+
scheduler: Scheduler
|
57 |
+
|
58 |
+
model_cpu_offload_seq = "text_encoder->transformer->vqvae"
|
59 |
+
|
60 |
+
# TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before
|
61 |
+
# the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter
|
62 |
+
# off the meta device. There should be a way to fix this instead of just not offloading it
|
63 |
+
_exclude_from_cpu_offload = ["vqvae"]
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
vqvae: VQModel,
|
68 |
+
tokenizer: CLIPTokenizer,
|
69 |
+
text_encoder: CLIPTextModelWithProjection,
|
70 |
+
transformer: Transformer2DModel, #UVit2DModel,
|
71 |
+
scheduler: Scheduler,
|
72 |
+
):
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
self.register_modules(
|
76 |
+
vqvae=vqvae,
|
77 |
+
tokenizer=tokenizer,
|
78 |
+
text_encoder=text_encoder,
|
79 |
+
transformer=transformer,
|
80 |
+
scheduler=scheduler,
|
81 |
+
)
|
82 |
+
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
83 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
|
84 |
+
self.mask_processor = VaeImageProcessor(
|
85 |
+
vae_scale_factor=self.vae_scale_factor,
|
86 |
+
do_normalize=False,
|
87 |
+
do_binarize=True,
|
88 |
+
do_convert_grayscale=True,
|
89 |
+
do_resize=True,
|
90 |
+
)
|
91 |
+
self.scheduler.register_to_config(masking_schedule="linear")
|
92 |
+
|
93 |
+
@torch.no_grad()
|
94 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
prompt: Optional[Union[List[str], str]] = None,
|
98 |
+
image: PipelineImageInput = None,
|
99 |
+
mask_image: PipelineImageInput = None,
|
100 |
+
strength: float = 1.0,
|
101 |
+
num_inference_steps: int = 12,
|
102 |
+
guidance_scale: float = 10.0,
|
103 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
104 |
+
num_images_per_prompt: Optional[int] = 1,
|
105 |
+
generator: Optional[torch.Generator] = None,
|
106 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
107 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
108 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
109 |
+
negative_encoder_hidden_states: Optional[torch.Tensor] = None,
|
110 |
+
output_type="pil",
|
111 |
+
return_dict: bool = True,
|
112 |
+
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
113 |
+
callback_steps: int = 1,
|
114 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
115 |
+
micro_conditioning_aesthetic_score: int = 6,
|
116 |
+
micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
|
117 |
+
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
118 |
+
):
|
119 |
+
"""
|
120 |
+
The call function to the pipeline for generation.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
prompt (`str` or `List[str]`, *optional*):
|
124 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
125 |
+
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
126 |
+
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
127 |
+
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
|
128 |
+
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
|
129 |
+
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
|
130 |
+
latents as `image`, but if passing latents directly it is not encoded again.
|
131 |
+
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
132 |
+
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
|
133 |
+
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
|
134 |
+
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
|
135 |
+
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
|
136 |
+
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
|
137 |
+
1)`, or `(H, W)`.
|
138 |
+
strength (`float`, *optional*, defaults to 1.0):
|
139 |
+
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
140 |
+
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
141 |
+
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
142 |
+
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
143 |
+
essentially ignores `image`.
|
144 |
+
num_inference_steps (`int`, *optional*, defaults to 16):
|
145 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
146 |
+
expense of slower inference.
|
147 |
+
guidance_scale (`float`, *optional*, defaults to 10.0):
|
148 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
149 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
150 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
151 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
152 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
153 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
154 |
+
The number of images to generate per prompt.
|
155 |
+
generator (`torch.Generator`, *optional*):
|
156 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
157 |
+
generation deterministic.
|
158 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
159 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
160 |
+
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
|
161 |
+
pooled and projected final hidden states.
|
162 |
+
encoder_hidden_states (`torch.Tensor`, *optional*):
|
163 |
+
Pre-generated penultimate hidden states from the text encoder providing additional text conditioning.
|
164 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
165 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
166 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
167 |
+
negative_encoder_hidden_states (`torch.Tensor`, *optional*):
|
168 |
+
Analogous to `encoder_hidden_states` for the positive prompt.
|
169 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
170 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
171 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
172 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
173 |
+
plain tuple.
|
174 |
+
callback (`Callable`, *optional*):
|
175 |
+
A function that calls every `callback_steps` steps during inference. The function is called with the
|
176 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
|
177 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
178 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
179 |
+
every step.
|
180 |
+
cross_attention_kwargs (`dict`, *optional*):
|
181 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
182 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
183 |
+
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
|
184 |
+
The targeted aesthetic score according to the laion aesthetic classifier. See
|
185 |
+
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
|
186 |
+
https://arxiv.org/abs/2307.01952.
|
187 |
+
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
188 |
+
The targeted height, width crop coordinates. See the micro-conditioning section of
|
189 |
+
https://arxiv.org/abs/2307.01952.
|
190 |
+
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
|
191 |
+
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
|
192 |
+
|
193 |
+
Examples:
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
[`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`:
|
197 |
+
If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a
|
198 |
+
`tuple` is returned where the first element is a list with the generated images.
|
199 |
+
"""
|
200 |
+
|
201 |
+
if (prompt_embeds is not None and encoder_hidden_states is None) or (
|
202 |
+
prompt_embeds is None and encoder_hidden_states is not None
|
203 |
+
):
|
204 |
+
raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")
|
205 |
+
|
206 |
+
if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
|
207 |
+
negative_prompt_embeds is None and negative_encoder_hidden_states is not None
|
208 |
+
):
|
209 |
+
raise ValueError(
|
210 |
+
"pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither"
|
211 |
+
)
|
212 |
+
|
213 |
+
if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None):
|
214 |
+
raise ValueError("pass only one of `prompt` or `prompt_embeds`")
|
215 |
+
|
216 |
+
if isinstance(prompt, str):
|
217 |
+
prompt = [prompt]
|
218 |
+
|
219 |
+
if prompt is not None:
|
220 |
+
batch_size = len(prompt)
|
221 |
+
else:
|
222 |
+
batch_size = prompt_embeds.shape[0]
|
223 |
+
|
224 |
+
batch_size = batch_size * num_images_per_prompt
|
225 |
+
|
226 |
+
if prompt_embeds is None:
|
227 |
+
input_ids = self.tokenizer(
|
228 |
+
prompt,
|
229 |
+
return_tensors="pt",
|
230 |
+
padding="max_length",
|
231 |
+
truncation=True,
|
232 |
+
max_length=77, #self.tokenizer.model_max_length,
|
233 |
+
).input_ids.to(self._execution_device)
|
234 |
+
|
235 |
+
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
236 |
+
prompt_embeds = outputs.text_embeds
|
237 |
+
encoder_hidden_states = outputs.hidden_states[-2]
|
238 |
+
|
239 |
+
prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
|
240 |
+
encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
241 |
+
|
242 |
+
if guidance_scale > 1.0:
|
243 |
+
if negative_prompt_embeds is None:
|
244 |
+
if negative_prompt is None:
|
245 |
+
negative_prompt = [""] * len(prompt)
|
246 |
+
|
247 |
+
if isinstance(negative_prompt, str):
|
248 |
+
negative_prompt = [negative_prompt]
|
249 |
+
|
250 |
+
input_ids = self.tokenizer(
|
251 |
+
negative_prompt,
|
252 |
+
return_tensors="pt",
|
253 |
+
padding="max_length",
|
254 |
+
truncation=True,
|
255 |
+
max_length=77, #self.tokenizer.model_max_length,
|
256 |
+
).input_ids.to(self._execution_device)
|
257 |
+
|
258 |
+
outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
|
259 |
+
negative_prompt_embeds = outputs.text_embeds
|
260 |
+
negative_encoder_hidden_states = outputs.hidden_states[-2]
|
261 |
+
|
262 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
|
263 |
+
negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)
|
264 |
+
|
265 |
+
prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
|
266 |
+
encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
|
267 |
+
|
268 |
+
image = self.image_processor.preprocess(image)
|
269 |
+
|
270 |
+
height, width = image.shape[-2:]
|
271 |
+
|
272 |
+
# Note that the micro conditionings _do_ flip the order of width, height for the original size
|
273 |
+
# and the crop coordinates. This is how it was done in the original code base
|
274 |
+
micro_conds = torch.tensor(
|
275 |
+
[
|
276 |
+
width,
|
277 |
+
height,
|
278 |
+
micro_conditioning_crop_coord[0],
|
279 |
+
micro_conditioning_crop_coord[1],
|
280 |
+
micro_conditioning_aesthetic_score,
|
281 |
+
],
|
282 |
+
device=self._execution_device,
|
283 |
+
dtype=encoder_hidden_states.dtype,
|
284 |
+
)
|
285 |
+
|
286 |
+
micro_conds = micro_conds.unsqueeze(0)
|
287 |
+
micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1)
|
288 |
+
|
289 |
+
self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
|
290 |
+
num_inference_steps = int(len(self.scheduler.timesteps) * strength)
|
291 |
+
start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps
|
292 |
+
|
293 |
+
needs_upcasting = False #self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
|
294 |
+
|
295 |
+
if needs_upcasting:
|
296 |
+
self.vqvae.float()
|
297 |
+
|
298 |
+
latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents
|
299 |
+
latents_bsz, channels, latents_height, latents_width = latents.shape
|
300 |
+
latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width)
|
301 |
+
|
302 |
+
mask = self.mask_processor.preprocess(
|
303 |
+
mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor
|
304 |
+
)
|
305 |
+
mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device)
|
306 |
+
latents[mask] = self.scheduler.config.mask_token_id
|
307 |
+
|
308 |
+
starting_mask_ratio = mask.sum() / latents.numel()
|
309 |
+
|
310 |
+
latents = latents.repeat(num_images_per_prompt, 1, 1)
|
311 |
+
|
312 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
313 |
+
for i in range(start_timestep_idx, len(self.scheduler.timesteps)):
|
314 |
+
timestep = self.scheduler.timesteps[i]
|
315 |
+
|
316 |
+
if guidance_scale > 1.0:
|
317 |
+
model_input = torch.cat([latents] * 2)
|
318 |
+
else:
|
319 |
+
model_input = latents
|
320 |
+
|
321 |
+
if height == 1024: #args.resolution == 1024:
|
322 |
+
img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[-2],model_input.shape[-1],model_input.device,model_input.dtype)
|
323 |
+
else:
|
324 |
+
img_ids = _prepare_latent_image_ids(model_input.shape[0],2*model_input.shape[-2],2*model_input.shape[-1],model_input.device,model_input.dtype)
|
325 |
+
txt_ids = torch.zeros(encoder_hidden_states.shape[1],3).to(device = encoder_hidden_states.device, dtype = encoder_hidden_states.dtype)
|
326 |
+
model_output = self.transformer(
|
327 |
+
model_input,
|
328 |
+
micro_conds=micro_conds,
|
329 |
+
pooled_projections=prompt_embeds,
|
330 |
+
encoder_hidden_states=encoder_hidden_states,
|
331 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
332 |
+
img_ids = img_ids,
|
333 |
+
txt_ids = txt_ids,
|
334 |
+
timestep = torch.tensor([timestep], device=model_input.device, dtype=torch.long),
|
335 |
+
)
|
336 |
+
|
337 |
+
if guidance_scale > 1.0:
|
338 |
+
uncond_logits, cond_logits = model_output.chunk(2)
|
339 |
+
model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
|
340 |
+
|
341 |
+
latents = self.scheduler.step(
|
342 |
+
model_output=model_output,
|
343 |
+
timestep=timestep,
|
344 |
+
sample=latents,
|
345 |
+
generator=generator,
|
346 |
+
starting_mask_ratio=starting_mask_ratio,
|
347 |
+
).prev_sample
|
348 |
+
|
349 |
+
if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
|
350 |
+
progress_bar.update()
|
351 |
+
if callback is not None and i % callback_steps == 0:
|
352 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
353 |
+
callback(step_idx, timestep, latents)
|
354 |
+
|
355 |
+
if output_type == "latent":
|
356 |
+
output = latents
|
357 |
+
else:
|
358 |
+
output = self.vqvae.decode(
|
359 |
+
latents,
|
360 |
+
force_not_quantize=True,
|
361 |
+
shape=(
|
362 |
+
batch_size,
|
363 |
+
height // self.vae_scale_factor,
|
364 |
+
width // self.vae_scale_factor,
|
365 |
+
self.vqvae.config.latent_channels,
|
366 |
+
),
|
367 |
+
).sample.clip(0, 1)
|
368 |
+
output = self.image_processor.postprocess(output, output_type)
|
369 |
+
|
370 |
+
if needs_upcasting:
|
371 |
+
self.vqvae.half()
|
372 |
+
|
373 |
+
self.maybe_free_model_hooks()
|
374 |
+
|
375 |
+
if not return_dict:
|
376 |
+
return (output,)
|
377 |
+
|
378 |
+
return ImagePipelineOutput(output)
|
src/scheduler.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from diffusers.utils import BaseOutput
|
22 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
23 |
+
|
24 |
+
|
25 |
+
def gumbel_noise(t, generator=None):
|
26 |
+
device = generator.device if generator is not None else t.device
|
27 |
+
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
|
28 |
+
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
|
29 |
+
|
30 |
+
|
31 |
+
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
|
32 |
+
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
|
33 |
+
sorted_confidence = torch.sort(confidence, dim=-1).values
|
34 |
+
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
|
35 |
+
masking = confidence < cut_off
|
36 |
+
return masking
|
37 |
+
|
38 |
+
|
39 |
+
@dataclass
|
40 |
+
class SchedulerOutput(BaseOutput):
|
41 |
+
"""
|
42 |
+
Output class for the scheduler's `step` function output.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
46 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
47 |
+
denoising loop.
|
48 |
+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
49 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
50 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
51 |
+
"""
|
52 |
+
|
53 |
+
prev_sample: torch.Tensor
|
54 |
+
pred_original_sample: torch.Tensor = None
|
55 |
+
|
56 |
+
|
57 |
+
class Scheduler(SchedulerMixin, ConfigMixin):
|
58 |
+
order = 1
|
59 |
+
|
60 |
+
temperatures: torch.Tensor
|
61 |
+
|
62 |
+
@register_to_config
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
mask_token_id: int,
|
66 |
+
masking_schedule: str = "cosine",
|
67 |
+
):
|
68 |
+
self.temperatures = None
|
69 |
+
self.timesteps = None
|
70 |
+
|
71 |
+
def set_timesteps(
|
72 |
+
self,
|
73 |
+
num_inference_steps: int,
|
74 |
+
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
|
75 |
+
device: Union[str, torch.device] = None,
|
76 |
+
):
|
77 |
+
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
|
78 |
+
|
79 |
+
if isinstance(temperature, (tuple, list)):
|
80 |
+
self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
|
81 |
+
else:
|
82 |
+
self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)
|
83 |
+
|
84 |
+
def step(
|
85 |
+
self,
|
86 |
+
model_output: torch.Tensor,
|
87 |
+
timestep: torch.long,
|
88 |
+
sample: torch.LongTensor,
|
89 |
+
starting_mask_ratio: int = 1,
|
90 |
+
generator: Optional[torch.Generator] = None,
|
91 |
+
return_dict: bool = True,
|
92 |
+
) -> Union[SchedulerOutput, Tuple]:
|
93 |
+
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
|
94 |
+
|
95 |
+
if two_dim_input:
|
96 |
+
batch_size, codebook_size, height, width = model_output.shape
|
97 |
+
sample = sample.reshape(batch_size, height * width)
|
98 |
+
model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)
|
99 |
+
|
100 |
+
unknown_map = sample == self.config.mask_token_id
|
101 |
+
|
102 |
+
probs = model_output.softmax(dim=-1)
|
103 |
+
|
104 |
+
device = probs.device
|
105 |
+
probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU
|
106 |
+
if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
|
107 |
+
probs_ = probs_.float() # multinomial is not implemented for cpu half precision
|
108 |
+
probs_ = probs_.reshape(-1, probs.size(-1))
|
109 |
+
pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
|
110 |
+
pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
|
111 |
+
pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)
|
112 |
+
|
113 |
+
if timestep == 0:
|
114 |
+
prev_sample = pred_original_sample
|
115 |
+
else:
|
116 |
+
seq_len = sample.shape[1]
|
117 |
+
step_idx = (self.timesteps == timestep).nonzero()
|
118 |
+
ratio = (step_idx + 1) / len(self.timesteps)
|
119 |
+
|
120 |
+
if self.config.masking_schedule == "cosine":
|
121 |
+
mask_ratio = torch.cos(ratio * math.pi / 2)
|
122 |
+
elif self.config.masking_schedule == "linear":
|
123 |
+
mask_ratio = 1 - ratio
|
124 |
+
else:
|
125 |
+
raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
|
126 |
+
|
127 |
+
mask_ratio = starting_mask_ratio * mask_ratio
|
128 |
+
|
129 |
+
mask_len = (seq_len * mask_ratio).floor()
|
130 |
+
# do not mask more than amount previously masked
|
131 |
+
mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
|
132 |
+
# mask at least one
|
133 |
+
mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)
|
134 |
+
|
135 |
+
selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
|
136 |
+
# Ignores the tokens given in the input by overwriting their confidence.
|
137 |
+
selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
|
138 |
+
|
139 |
+
masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)
|
140 |
+
|
141 |
+
# Masks tokens with lower confidence.
|
142 |
+
prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)
|
143 |
+
|
144 |
+
if two_dim_input:
|
145 |
+
prev_sample = prev_sample.reshape(batch_size, height, width)
|
146 |
+
pred_original_sample = pred_original_sample.reshape(batch_size, height, width)
|
147 |
+
|
148 |
+
if not return_dict:
|
149 |
+
return (prev_sample, pred_original_sample)
|
150 |
+
|
151 |
+
return SchedulerOutput(prev_sample, pred_original_sample)
|
152 |
+
|
153 |
+
def add_noise(self, sample, timesteps, generator=None):
|
154 |
+
step_idx = (self.timesteps == timesteps).nonzero()
|
155 |
+
ratio = (step_idx + 1) / len(self.timesteps)
|
156 |
+
|
157 |
+
if self.config.masking_schedule == "cosine":
|
158 |
+
mask_ratio = torch.cos(ratio * math.pi / 2)
|
159 |
+
elif self.config.masking_schedule == "linear":
|
160 |
+
mask_ratio = 1 - ratio
|
161 |
+
else:
|
162 |
+
raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")
|
163 |
+
|
164 |
+
mask_indices = (
|
165 |
+
torch.rand(
|
166 |
+
sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
|
167 |
+
).to(sample.device)
|
168 |
+
< mask_ratio
|
169 |
+
)
|
170 |
+
|
171 |
+
masked_sample = sample.clone()
|
172 |
+
|
173 |
+
masked_sample[mask_indices] = self.config.mask_token_id
|
174 |
+
|
175 |
+
return masked_sample
|
src/transformer.py
ADDED
@@ -0,0 +1,1215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Black Forest Labs, The HuggingFace Team, The InstantX Team and The MeissonFlow Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
25 |
+
from diffusers.models.attention import FeedForward, BasicTransformerBlock, SkipFFTransformerBlock
|
26 |
+
from diffusers.models.attention_processor import (
|
27 |
+
Attention,
|
28 |
+
AttentionProcessor,
|
29 |
+
FluxAttnProcessor2_0,
|
30 |
+
# FusedFluxAttnProcessor2_0,
|
31 |
+
)
|
32 |
+
from diffusers.models.modeling_utils import ModelMixin
|
33 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, GlobalResponseNorm, RMSNorm
|
34 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
35 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
36 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings,TimestepEmbedding, get_timestep_embedding #,FluxPosEmbed
|
37 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
38 |
+
from diffusers.models.resnet import Downsample2D, Upsample2D
|
39 |
+
|
40 |
+
from typing import List
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
def get_3d_rotary_pos_embed(
|
47 |
+
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
48 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
49 |
+
"""
|
50 |
+
RoPE for video tokens with 3D structure.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
embed_dim: (`int`):
|
54 |
+
The embedding dimension size, corresponding to hidden_size_head.
|
55 |
+
crops_coords (`Tuple[int]`):
|
56 |
+
The top-left and bottom-right coordinates of the crop.
|
57 |
+
grid_size (`Tuple[int]`):
|
58 |
+
The grid size of the spatial positional embedding (height, width).
|
59 |
+
temporal_size (`int`):
|
60 |
+
The size of the temporal dimension.
|
61 |
+
theta (`float`):
|
62 |
+
Scaling factor for frequency computation.
|
63 |
+
use_real (`bool`):
|
64 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
68 |
+
"""
|
69 |
+
start, stop = crops_coords
|
70 |
+
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
71 |
+
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
72 |
+
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
73 |
+
|
74 |
+
# Compute dimensions for each axis
|
75 |
+
dim_t = embed_dim // 4
|
76 |
+
dim_h = embed_dim // 8 * 3
|
77 |
+
dim_w = embed_dim // 8 * 3
|
78 |
+
|
79 |
+
# Temporal frequencies
|
80 |
+
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
81 |
+
grid_t = torch.from_numpy(grid_t).float()
|
82 |
+
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
83 |
+
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
84 |
+
|
85 |
+
# Spatial frequencies for height and width
|
86 |
+
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
87 |
+
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
88 |
+
grid_h = torch.from_numpy(grid_h).float()
|
89 |
+
grid_w = torch.from_numpy(grid_w).float()
|
90 |
+
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
91 |
+
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
92 |
+
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
93 |
+
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
94 |
+
|
95 |
+
# Broadcast and concatenate tensors along specified dimension
|
96 |
+
def broadcast(tensors, dim=-1):
|
97 |
+
num_tensors = len(tensors)
|
98 |
+
shape_lens = {len(t.shape) for t in tensors}
|
99 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
100 |
+
shape_len = list(shape_lens)[0]
|
101 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
102 |
+
dims = list(zip(*(list(t.shape) for t in tensors)))
|
103 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
104 |
+
assert all(
|
105 |
+
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
106 |
+
), "invalid dimensions for broadcastable concatenation"
|
107 |
+
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
108 |
+
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
109 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
110 |
+
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
111 |
+
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
112 |
+
return torch.cat(tensors, dim=dim)
|
113 |
+
|
114 |
+
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
115 |
+
|
116 |
+
t, h, w, d = freqs.shape
|
117 |
+
freqs = freqs.view(t * h * w, d)
|
118 |
+
|
119 |
+
# Generate sine and cosine components
|
120 |
+
sin = freqs.sin()
|
121 |
+
cos = freqs.cos()
|
122 |
+
|
123 |
+
if use_real:
|
124 |
+
return cos, sin
|
125 |
+
else:
|
126 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
127 |
+
return freqs_cis
|
128 |
+
|
129 |
+
|
130 |
+
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
131 |
+
"""
|
132 |
+
RoPE for image tokens with 2d structure.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
embed_dim: (`int`):
|
136 |
+
The embedding dimension size
|
137 |
+
crops_coords (`Tuple[int]`)
|
138 |
+
The top-left and bottom-right coordinates of the crop.
|
139 |
+
grid_size (`Tuple[int]`):
|
140 |
+
The grid size of the positional embedding.
|
141 |
+
use_real (`bool`):
|
142 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
|
146 |
+
"""
|
147 |
+
start, stop = crops_coords
|
148 |
+
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
149 |
+
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
150 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
151 |
+
grid = np.stack(grid, axis=0) # [2, W, H]
|
152 |
+
|
153 |
+
grid = grid.reshape([2, 1, *grid.shape[1:]])
|
154 |
+
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
155 |
+
return pos_embed
|
156 |
+
|
157 |
+
|
158 |
+
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
159 |
+
assert embed_dim % 4 == 0
|
160 |
+
|
161 |
+
# use half of dimensions to encode grid_h
|
162 |
+
emb_h = get_1d_rotary_pos_embed(
|
163 |
+
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
|
164 |
+
) # (H*W, D/2) if use_real else (H*W, D/4)
|
165 |
+
emb_w = get_1d_rotary_pos_embed(
|
166 |
+
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
|
167 |
+
) # (H*W, D/2) if use_real else (H*W, D/4)
|
168 |
+
|
169 |
+
if use_real:
|
170 |
+
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
|
171 |
+
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
|
172 |
+
return cos, sin
|
173 |
+
else:
|
174 |
+
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
175 |
+
return emb
|
176 |
+
|
177 |
+
|
178 |
+
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
|
179 |
+
assert embed_dim % 4 == 0
|
180 |
+
|
181 |
+
emb_h = get_1d_rotary_pos_embed(
|
182 |
+
embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
|
183 |
+
) # (H, D/4)
|
184 |
+
emb_w = get_1d_rotary_pos_embed(
|
185 |
+
embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
|
186 |
+
) # (W, D/4)
|
187 |
+
emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
|
188 |
+
emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
|
189 |
+
|
190 |
+
emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
|
191 |
+
return emb
|
192 |
+
|
193 |
+
|
194 |
+
def get_1d_rotary_pos_embed(
|
195 |
+
dim: int,
|
196 |
+
pos: Union[np.ndarray, int],
|
197 |
+
theta: float = 10000.0,
|
198 |
+
use_real=False,
|
199 |
+
linear_factor=1.0,
|
200 |
+
ntk_factor=1.0,
|
201 |
+
repeat_interleave_real=True,
|
202 |
+
freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
|
203 |
+
):
|
204 |
+
"""
|
205 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
206 |
+
|
207 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
|
208 |
+
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
209 |
+
data type.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
dim (`int`): Dimension of the frequency tensor.
|
213 |
+
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
214 |
+
theta (`float`, *optional*, defaults to 10000.0):
|
215 |
+
Scaling factor for frequency computation. Defaults to 10000.0.
|
216 |
+
use_real (`bool`, *optional*):
|
217 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
218 |
+
linear_factor (`float`, *optional*, defaults to 1.0):
|
219 |
+
Scaling factor for the context extrapolation. Defaults to 1.0.
|
220 |
+
ntk_factor (`float`, *optional*, defaults to 1.0):
|
221 |
+
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
222 |
+
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
223 |
+
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
224 |
+
Otherwise, they are concateanted with themselves.
|
225 |
+
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
226 |
+
the dtype of the frequency tensor.
|
227 |
+
Returns:
|
228 |
+
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
229 |
+
"""
|
230 |
+
assert dim % 2 == 0
|
231 |
+
|
232 |
+
if isinstance(pos, int):
|
233 |
+
pos = np.arange(pos)
|
234 |
+
theta = theta * ntk_factor
|
235 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
|
236 |
+
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
237 |
+
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
|
238 |
+
if use_real and repeat_interleave_real:
|
239 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
240 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
241 |
+
return freqs_cos, freqs_sin
|
242 |
+
elif use_real:
|
243 |
+
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
244 |
+
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
245 |
+
return freqs_cos, freqs_sin
|
246 |
+
else:
|
247 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
|
248 |
+
return freqs_cis
|
249 |
+
|
250 |
+
|
251 |
+
class FluxPosEmbed(nn.Module):
|
252 |
+
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
253 |
+
def __init__(self, theta: int, axes_dim: List[int]):
|
254 |
+
super().__init__()
|
255 |
+
self.theta = theta
|
256 |
+
self.axes_dim = axes_dim
|
257 |
+
|
258 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
259 |
+
n_axes = ids.shape[-1]
|
260 |
+
cos_out = []
|
261 |
+
sin_out = []
|
262 |
+
pos = ids.squeeze().float().cpu().numpy()
|
263 |
+
is_mps = ids.device.type == "mps"
|
264 |
+
freqs_dtype = torch.float32 if is_mps else torch.float64
|
265 |
+
for i in range(n_axes):
|
266 |
+
cos, sin = get_1d_rotary_pos_embed(
|
267 |
+
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
|
268 |
+
)
|
269 |
+
cos_out.append(cos)
|
270 |
+
sin_out.append(sin)
|
271 |
+
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
272 |
+
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
273 |
+
return freqs_cos, freqs_sin
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
class FusedFluxAttnProcessor2_0:
|
278 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
279 |
+
|
280 |
+
def __init__(self):
|
281 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
282 |
+
raise ImportError(
|
283 |
+
"FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
284 |
+
)
|
285 |
+
|
286 |
+
def __call__(
|
287 |
+
self,
|
288 |
+
attn: Attention,
|
289 |
+
hidden_states: torch.FloatTensor,
|
290 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
291 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
292 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
293 |
+
) -> torch.FloatTensor:
|
294 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
295 |
+
|
296 |
+
# `sample` projections.
|
297 |
+
qkv = attn.to_qkv(hidden_states)
|
298 |
+
split_size = qkv.shape[-1] // 3
|
299 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
300 |
+
|
301 |
+
inner_dim = key.shape[-1]
|
302 |
+
head_dim = inner_dim // attn.heads
|
303 |
+
|
304 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
305 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
306 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
307 |
+
|
308 |
+
if attn.norm_q is not None:
|
309 |
+
query = attn.norm_q(query)
|
310 |
+
if attn.norm_k is not None:
|
311 |
+
key = attn.norm_k(key)
|
312 |
+
|
313 |
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
314 |
+
# `context` projections.
|
315 |
+
if encoder_hidden_states is not None:
|
316 |
+
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
317 |
+
split_size = encoder_qkv.shape[-1] // 3
|
318 |
+
(
|
319 |
+
encoder_hidden_states_query_proj,
|
320 |
+
encoder_hidden_states_key_proj,
|
321 |
+
encoder_hidden_states_value_proj,
|
322 |
+
) = torch.split(encoder_qkv, split_size, dim=-1)
|
323 |
+
|
324 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
325 |
+
batch_size, -1, attn.heads, head_dim
|
326 |
+
).transpose(1, 2)
|
327 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
328 |
+
batch_size, -1, attn.heads, head_dim
|
329 |
+
).transpose(1, 2)
|
330 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
331 |
+
batch_size, -1, attn.heads, head_dim
|
332 |
+
).transpose(1, 2)
|
333 |
+
|
334 |
+
if attn.norm_added_q is not None:
|
335 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
336 |
+
if attn.norm_added_k is not None:
|
337 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
338 |
+
|
339 |
+
# attention
|
340 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
341 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
342 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
343 |
+
|
344 |
+
if image_rotary_emb is not None:
|
345 |
+
from .embeddings import apply_rotary_emb
|
346 |
+
|
347 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
348 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
349 |
+
|
350 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
351 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
352 |
+
hidden_states = hidden_states.to(query.dtype)
|
353 |
+
|
354 |
+
if encoder_hidden_states is not None:
|
355 |
+
encoder_hidden_states, hidden_states = (
|
356 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
357 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
358 |
+
)
|
359 |
+
|
360 |
+
# linear proj
|
361 |
+
hidden_states = attn.to_out[0](hidden_states)
|
362 |
+
# dropout
|
363 |
+
hidden_states = attn.to_out[1](hidden_states)
|
364 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
365 |
+
|
366 |
+
return hidden_states, encoder_hidden_states
|
367 |
+
else:
|
368 |
+
return hidden_states
|
369 |
+
|
370 |
+
|
371 |
+
|
372 |
+
@maybe_allow_in_graph
|
373 |
+
class SingleTransformerBlock(nn.Module):
|
374 |
+
r"""
|
375 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
376 |
+
|
377 |
+
Reference: https://arxiv.org/abs/2403.03206
|
378 |
+
|
379 |
+
Parameters:
|
380 |
+
dim (`int`): The number of channels in the input and output.
|
381 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
382 |
+
attention_head_dim (`int`): The number of channels in each head.
|
383 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
384 |
+
processing of `context` conditions.
|
385 |
+
"""
|
386 |
+
|
387 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
388 |
+
super().__init__()
|
389 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
390 |
+
|
391 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
392 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
393 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
394 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
395 |
+
|
396 |
+
processor = FluxAttnProcessor2_0()
|
397 |
+
self.attn = Attention(
|
398 |
+
query_dim=dim,
|
399 |
+
cross_attention_dim=None,
|
400 |
+
dim_head=attention_head_dim,
|
401 |
+
heads=num_attention_heads,
|
402 |
+
out_dim=dim,
|
403 |
+
bias=True,
|
404 |
+
processor=processor,
|
405 |
+
qk_norm="rms_norm",
|
406 |
+
eps=1e-6,
|
407 |
+
pre_only=True,
|
408 |
+
)
|
409 |
+
|
410 |
+
def forward(
|
411 |
+
self,
|
412 |
+
hidden_states: torch.FloatTensor,
|
413 |
+
temb: torch.FloatTensor,
|
414 |
+
image_rotary_emb=None,
|
415 |
+
):
|
416 |
+
residual = hidden_states
|
417 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
418 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
419 |
+
|
420 |
+
attn_output = self.attn(
|
421 |
+
hidden_states=norm_hidden_states,
|
422 |
+
image_rotary_emb=image_rotary_emb,
|
423 |
+
)
|
424 |
+
|
425 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
426 |
+
gate = gate.unsqueeze(1)
|
427 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
428 |
+
hidden_states = residual + hidden_states
|
429 |
+
if hidden_states.dtype == torch.float16:
|
430 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
431 |
+
|
432 |
+
return hidden_states
|
433 |
+
|
434 |
+
@maybe_allow_in_graph
|
435 |
+
class TransformerBlock(nn.Module):
|
436 |
+
r"""
|
437 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
438 |
+
|
439 |
+
Reference: https://arxiv.org/abs/2403.03206
|
440 |
+
|
441 |
+
Parameters:
|
442 |
+
dim (`int`): The number of channels in the input and output.
|
443 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
444 |
+
attention_head_dim (`int`): The number of channels in each head.
|
445 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
446 |
+
processing of `context` conditions.
|
447 |
+
"""
|
448 |
+
|
449 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
|
450 |
+
super().__init__()
|
451 |
+
|
452 |
+
self.norm1 = AdaLayerNormZero(dim)
|
453 |
+
|
454 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
455 |
+
|
456 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
457 |
+
processor = FluxAttnProcessor2_0()
|
458 |
+
else:
|
459 |
+
raise ValueError(
|
460 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
461 |
+
)
|
462 |
+
self.attn = Attention(
|
463 |
+
query_dim=dim,
|
464 |
+
cross_attention_dim=None,
|
465 |
+
added_kv_proj_dim=dim,
|
466 |
+
dim_head=attention_head_dim,
|
467 |
+
heads=num_attention_heads,
|
468 |
+
out_dim=dim,
|
469 |
+
context_pre_only=False,
|
470 |
+
bias=True,
|
471 |
+
processor=processor,
|
472 |
+
qk_norm=qk_norm,
|
473 |
+
eps=eps,
|
474 |
+
)
|
475 |
+
|
476 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
477 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
478 |
+
# self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
|
479 |
+
|
480 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
481 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
482 |
+
# self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
|
483 |
+
|
484 |
+
# let chunk size default to None
|
485 |
+
self._chunk_size = None
|
486 |
+
self._chunk_dim = 0
|
487 |
+
|
488 |
+
def forward(
|
489 |
+
self,
|
490 |
+
hidden_states: torch.FloatTensor,
|
491 |
+
encoder_hidden_states: torch.FloatTensor,
|
492 |
+
temb: torch.FloatTensor,
|
493 |
+
image_rotary_emb=None,
|
494 |
+
):
|
495 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
496 |
+
|
497 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
498 |
+
encoder_hidden_states, emb=temb
|
499 |
+
)
|
500 |
+
# Attention.
|
501 |
+
attn_output, context_attn_output = self.attn(
|
502 |
+
hidden_states=norm_hidden_states,
|
503 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
504 |
+
image_rotary_emb=image_rotary_emb,
|
505 |
+
)
|
506 |
+
|
507 |
+
# Process attention outputs for the `hidden_states`.
|
508 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
509 |
+
hidden_states = hidden_states + attn_output
|
510 |
+
|
511 |
+
norm_hidden_states = self.norm2(hidden_states)
|
512 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
513 |
+
|
514 |
+
ff_output = self.ff(norm_hidden_states)
|
515 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
516 |
+
|
517 |
+
hidden_states = hidden_states + ff_output
|
518 |
+
|
519 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
520 |
+
|
521 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
522 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
523 |
+
|
524 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
525 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
526 |
+
|
527 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
528 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
529 |
+
if encoder_hidden_states.dtype == torch.float16:
|
530 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
531 |
+
|
532 |
+
return encoder_hidden_states, hidden_states
|
533 |
+
|
534 |
+
|
535 |
+
class UVit2DConvEmbed(nn.Module):
|
536 |
+
def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
|
537 |
+
super().__init__()
|
538 |
+
self.embeddings = nn.Embedding(vocab_size, in_channels)
|
539 |
+
self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
|
540 |
+
self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
|
541 |
+
|
542 |
+
def forward(self, input_ids):
|
543 |
+
embeddings = self.embeddings(input_ids)
|
544 |
+
embeddings = self.layer_norm(embeddings)
|
545 |
+
embeddings = embeddings.permute(0, 3, 1, 2)
|
546 |
+
embeddings = self.conv(embeddings)
|
547 |
+
return embeddings
|
548 |
+
|
549 |
+
class ConvMlmLayer(nn.Module):
|
550 |
+
def __init__(
|
551 |
+
self,
|
552 |
+
block_out_channels: int,
|
553 |
+
in_channels: int,
|
554 |
+
use_bias: bool,
|
555 |
+
ln_elementwise_affine: bool,
|
556 |
+
layer_norm_eps: float,
|
557 |
+
codebook_size: int,
|
558 |
+
):
|
559 |
+
super().__init__()
|
560 |
+
self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
|
561 |
+
self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
|
562 |
+
self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
|
563 |
+
|
564 |
+
def forward(self, hidden_states):
|
565 |
+
hidden_states = self.conv1(hidden_states)
|
566 |
+
hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
567 |
+
logits = self.conv2(hidden_states)
|
568 |
+
return logits
|
569 |
+
|
570 |
+
class SwiGLU(nn.Module):
|
571 |
+
r"""
|
572 |
+
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
|
573 |
+
but uses SiLU / Swish instead of GeLU.
|
574 |
+
|
575 |
+
Parameters:
|
576 |
+
dim_in (`int`): The number of channels in the input.
|
577 |
+
dim_out (`int`): The number of channels in the output.
|
578 |
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
579 |
+
"""
|
580 |
+
|
581 |
+
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
582 |
+
super().__init__()
|
583 |
+
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
584 |
+
self.activation = nn.SiLU()
|
585 |
+
|
586 |
+
def forward(self, hidden_states):
|
587 |
+
hidden_states = self.proj(hidden_states)
|
588 |
+
hidden_states, gate = hidden_states.chunk(2, dim=-1)
|
589 |
+
return hidden_states * self.activation(gate)
|
590 |
+
|
591 |
+
class ConvNextBlock(nn.Module):
|
592 |
+
def __init__(
|
593 |
+
self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
|
594 |
+
):
|
595 |
+
super().__init__()
|
596 |
+
self.depthwise = nn.Conv2d(
|
597 |
+
channels,
|
598 |
+
channels,
|
599 |
+
kernel_size=3,
|
600 |
+
padding=1,
|
601 |
+
groups=channels,
|
602 |
+
bias=use_bias,
|
603 |
+
)
|
604 |
+
self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
|
605 |
+
self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
|
606 |
+
self.channelwise_act = nn.GELU()
|
607 |
+
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
|
608 |
+
self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
|
609 |
+
self.channelwise_dropout = nn.Dropout(hidden_dropout)
|
610 |
+
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
|
611 |
+
|
612 |
+
def forward(self, x, cond_embeds):
|
613 |
+
x_res = x
|
614 |
+
|
615 |
+
x = self.depthwise(x)
|
616 |
+
|
617 |
+
x = x.permute(0, 2, 3, 1)
|
618 |
+
x = self.norm(x)
|
619 |
+
|
620 |
+
x = self.channelwise_linear_1(x)
|
621 |
+
x = self.channelwise_act(x)
|
622 |
+
x = self.channelwise_norm(x)
|
623 |
+
x = self.channelwise_linear_2(x)
|
624 |
+
x = self.channelwise_dropout(x)
|
625 |
+
|
626 |
+
x = x.permute(0, 3, 1, 2)
|
627 |
+
|
628 |
+
x = x + x_res
|
629 |
+
|
630 |
+
scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
|
631 |
+
x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
|
632 |
+
|
633 |
+
return x
|
634 |
+
|
635 |
+
class Simple_UVitBlock(nn.Module):
|
636 |
+
def __init__(
|
637 |
+
self,
|
638 |
+
channels,
|
639 |
+
ln_elementwise_affine,
|
640 |
+
layer_norm_eps,
|
641 |
+
use_bias,
|
642 |
+
downsample: bool,
|
643 |
+
upsample: bool,
|
644 |
+
):
|
645 |
+
super().__init__()
|
646 |
+
|
647 |
+
if downsample:
|
648 |
+
self.downsample = Downsample2D(
|
649 |
+
channels,
|
650 |
+
use_conv=True,
|
651 |
+
padding=0,
|
652 |
+
name="Conv2d_0",
|
653 |
+
kernel_size=2,
|
654 |
+
norm_type="rms_norm",
|
655 |
+
eps=layer_norm_eps,
|
656 |
+
elementwise_affine=ln_elementwise_affine,
|
657 |
+
bias=use_bias,
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
self.downsample = None
|
661 |
+
|
662 |
+
if upsample:
|
663 |
+
self.upsample = Upsample2D(
|
664 |
+
channels,
|
665 |
+
use_conv_transpose=True,
|
666 |
+
kernel_size=2,
|
667 |
+
padding=0,
|
668 |
+
name="conv",
|
669 |
+
norm_type="rms_norm",
|
670 |
+
eps=layer_norm_eps,
|
671 |
+
elementwise_affine=ln_elementwise_affine,
|
672 |
+
bias=use_bias,
|
673 |
+
interpolate=False,
|
674 |
+
)
|
675 |
+
else:
|
676 |
+
self.upsample = None
|
677 |
+
|
678 |
+
def forward(self, x):
|
679 |
+
# print("before,", x.shape)
|
680 |
+
if self.downsample is not None:
|
681 |
+
# print('downsample')
|
682 |
+
x = self.downsample(x)
|
683 |
+
|
684 |
+
if self.upsample is not None:
|
685 |
+
# print('upsample')
|
686 |
+
x = self.upsample(x)
|
687 |
+
# print("after,", x.shape)
|
688 |
+
return x
|
689 |
+
|
690 |
+
|
691 |
+
class UVitBlock(nn.Module):
|
692 |
+
def __init__(
|
693 |
+
self,
|
694 |
+
channels,
|
695 |
+
num_res_blocks: int,
|
696 |
+
hidden_size,
|
697 |
+
hidden_dropout,
|
698 |
+
ln_elementwise_affine,
|
699 |
+
layer_norm_eps,
|
700 |
+
use_bias,
|
701 |
+
block_num_heads,
|
702 |
+
attention_dropout,
|
703 |
+
downsample: bool,
|
704 |
+
upsample: bool,
|
705 |
+
):
|
706 |
+
super().__init__()
|
707 |
+
|
708 |
+
if downsample:
|
709 |
+
self.downsample = Downsample2D(
|
710 |
+
channels,
|
711 |
+
use_conv=True,
|
712 |
+
padding=0,
|
713 |
+
name="Conv2d_0",
|
714 |
+
kernel_size=2,
|
715 |
+
norm_type="rms_norm",
|
716 |
+
eps=layer_norm_eps,
|
717 |
+
elementwise_affine=ln_elementwise_affine,
|
718 |
+
bias=use_bias,
|
719 |
+
)
|
720 |
+
else:
|
721 |
+
self.downsample = None
|
722 |
+
|
723 |
+
self.res_blocks = nn.ModuleList(
|
724 |
+
[
|
725 |
+
ConvNextBlock(
|
726 |
+
channels,
|
727 |
+
layer_norm_eps,
|
728 |
+
ln_elementwise_affine,
|
729 |
+
use_bias,
|
730 |
+
hidden_dropout,
|
731 |
+
hidden_size,
|
732 |
+
)
|
733 |
+
for i in range(num_res_blocks)
|
734 |
+
]
|
735 |
+
)
|
736 |
+
|
737 |
+
self.attention_blocks = nn.ModuleList(
|
738 |
+
[
|
739 |
+
SkipFFTransformerBlock(
|
740 |
+
channels,
|
741 |
+
block_num_heads,
|
742 |
+
channels // block_num_heads,
|
743 |
+
hidden_size,
|
744 |
+
use_bias,
|
745 |
+
attention_dropout,
|
746 |
+
channels,
|
747 |
+
attention_bias=use_bias,
|
748 |
+
attention_out_bias=use_bias,
|
749 |
+
)
|
750 |
+
for _ in range(num_res_blocks)
|
751 |
+
]
|
752 |
+
)
|
753 |
+
|
754 |
+
if upsample:
|
755 |
+
self.upsample = Upsample2D(
|
756 |
+
channels,
|
757 |
+
use_conv_transpose=True,
|
758 |
+
kernel_size=2,
|
759 |
+
padding=0,
|
760 |
+
name="conv",
|
761 |
+
norm_type="rms_norm",
|
762 |
+
eps=layer_norm_eps,
|
763 |
+
elementwise_affine=ln_elementwise_affine,
|
764 |
+
bias=use_bias,
|
765 |
+
interpolate=False,
|
766 |
+
)
|
767 |
+
else:
|
768 |
+
self.upsample = None
|
769 |
+
|
770 |
+
def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
|
771 |
+
if self.downsample is not None:
|
772 |
+
x = self.downsample(x)
|
773 |
+
|
774 |
+
for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
|
775 |
+
x = res_block(x, pooled_text_emb)
|
776 |
+
|
777 |
+
batch_size, channels, height, width = x.shape
|
778 |
+
x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
|
779 |
+
x = attention_block(
|
780 |
+
x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
|
781 |
+
)
|
782 |
+
x = x.permute(0, 2, 1).view(batch_size, channels, height, width)
|
783 |
+
|
784 |
+
if self.upsample is not None:
|
785 |
+
x = self.upsample(x)
|
786 |
+
|
787 |
+
return x
|
788 |
+
|
789 |
+
class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
790 |
+
"""
|
791 |
+
The Transformer model introduced in Flux.
|
792 |
+
|
793 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
794 |
+
|
795 |
+
Parameters:
|
796 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
797 |
+
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
798 |
+
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
799 |
+
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
800 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
801 |
+
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
802 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
803 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
804 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
805 |
+
"""
|
806 |
+
|
807 |
+
_supports_gradient_checkpointing = False #True
|
808 |
+
# Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph.
|
809 |
+
# Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
|
810 |
+
_no_split_modules = ["TransformerBlock", "SingleTransformerBlock"]
|
811 |
+
|
812 |
+
@register_to_config
|
813 |
+
def __init__(
|
814 |
+
self,
|
815 |
+
patch_size: int = 1,
|
816 |
+
in_channels: int = 64,
|
817 |
+
num_layers: int = 19,
|
818 |
+
num_single_layers: int = 38,
|
819 |
+
attention_head_dim: int = 128,
|
820 |
+
num_attention_heads: int = 24,
|
821 |
+
joint_attention_dim: int = 4096,
|
822 |
+
pooled_projection_dim: int = 768,
|
823 |
+
guidance_embeds: bool = False, # unused in our implementation
|
824 |
+
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
825 |
+
vocab_size: int = 8256,
|
826 |
+
codebook_size: int = 8192,
|
827 |
+
downsample: bool = False,
|
828 |
+
upsample: bool = False,
|
829 |
+
):
|
830 |
+
super().__init__()
|
831 |
+
self.out_channels = in_channels
|
832 |
+
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
833 |
+
|
834 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
835 |
+
text_time_guidance_cls = (
|
836 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
837 |
+
)
|
838 |
+
self.time_text_embed = text_time_guidance_cls(
|
839 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
|
840 |
+
)
|
841 |
+
|
842 |
+
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
|
843 |
+
|
844 |
+
self.transformer_blocks = nn.ModuleList(
|
845 |
+
[
|
846 |
+
TransformerBlock(
|
847 |
+
dim=self.inner_dim,
|
848 |
+
num_attention_heads=self.config.num_attention_heads,
|
849 |
+
attention_head_dim=self.config.attention_head_dim,
|
850 |
+
)
|
851 |
+
for i in range(self.config.num_layers)
|
852 |
+
]
|
853 |
+
)
|
854 |
+
|
855 |
+
self.single_transformer_blocks = nn.ModuleList(
|
856 |
+
[
|
857 |
+
SingleTransformerBlock(
|
858 |
+
dim=self.inner_dim,
|
859 |
+
num_attention_heads=self.config.num_attention_heads,
|
860 |
+
attention_head_dim=self.config.attention_head_dim,
|
861 |
+
)
|
862 |
+
for i in range(self.config.num_single_layers)
|
863 |
+
]
|
864 |
+
)
|
865 |
+
|
866 |
+
|
867 |
+
self.gradient_checkpointing = False
|
868 |
+
|
869 |
+
in_channels_embed = self.inner_dim
|
870 |
+
ln_elementwise_affine = True
|
871 |
+
layer_norm_eps = 1e-06
|
872 |
+
use_bias = False
|
873 |
+
micro_cond_embed_dim = 1280
|
874 |
+
self.embed = UVit2DConvEmbed(
|
875 |
+
in_channels_embed, self.inner_dim, self.config.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
|
876 |
+
)
|
877 |
+
self.mlm_layer = ConvMlmLayer(
|
878 |
+
self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.config.codebook_size
|
879 |
+
)
|
880 |
+
self.cond_embed = TimestepEmbedding(
|
881 |
+
micro_cond_embed_dim + self.config.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias
|
882 |
+
)
|
883 |
+
self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
|
884 |
+
self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine)
|
885 |
+
self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias)
|
886 |
+
self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
|
887 |
+
self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias)
|
888 |
+
|
889 |
+
self.down_block = Simple_UVitBlock(
|
890 |
+
self.inner_dim,
|
891 |
+
ln_elementwise_affine,
|
892 |
+
layer_norm_eps,
|
893 |
+
use_bias,
|
894 |
+
downsample,
|
895 |
+
False,
|
896 |
+
)
|
897 |
+
self.up_block = Simple_UVitBlock(
|
898 |
+
self.inner_dim, #block_out_channels,
|
899 |
+
ln_elementwise_affine,
|
900 |
+
layer_norm_eps,
|
901 |
+
use_bias,
|
902 |
+
False,
|
903 |
+
upsample=upsample,
|
904 |
+
)
|
905 |
+
|
906 |
+
# self.fuse_qkv_projections()
|
907 |
+
|
908 |
+
@property
|
909 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
910 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
911 |
+
r"""
|
912 |
+
Returns:
|
913 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
914 |
+
indexed by its weight name.
|
915 |
+
"""
|
916 |
+
# set recursively
|
917 |
+
processors = {}
|
918 |
+
|
919 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
920 |
+
if hasattr(module, "get_processor"):
|
921 |
+
processors[f"{name}.processor"] = module.get_processor()
|
922 |
+
|
923 |
+
for sub_name, child in module.named_children():
|
924 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
925 |
+
|
926 |
+
return processors
|
927 |
+
|
928 |
+
for name, module in self.named_children():
|
929 |
+
fn_recursive_add_processors(name, module, processors)
|
930 |
+
|
931 |
+
return processors
|
932 |
+
|
933 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
934 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
935 |
+
r"""
|
936 |
+
Sets the attention processor to use to compute attention.
|
937 |
+
|
938 |
+
Parameters:
|
939 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
940 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
941 |
+
for **all** `Attention` layers.
|
942 |
+
|
943 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
944 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
945 |
+
|
946 |
+
"""
|
947 |
+
count = len(self.attn_processors.keys())
|
948 |
+
|
949 |
+
if isinstance(processor, dict) and len(processor) != count:
|
950 |
+
raise ValueError(
|
951 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
952 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
953 |
+
)
|
954 |
+
|
955 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
956 |
+
if hasattr(module, "set_processor"):
|
957 |
+
if not isinstance(processor, dict):
|
958 |
+
module.set_processor(processor)
|
959 |
+
else:
|
960 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
961 |
+
|
962 |
+
for sub_name, child in module.named_children():
|
963 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
964 |
+
|
965 |
+
for name, module in self.named_children():
|
966 |
+
fn_recursive_attn_processor(name, module, processor)
|
967 |
+
|
968 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
969 |
+
def fuse_qkv_projections(self):
|
970 |
+
"""
|
971 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
972 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
973 |
+
|
974 |
+
<Tip warning={true}>
|
975 |
+
|
976 |
+
This API is 🧪 experimental.
|
977 |
+
|
978 |
+
</Tip>
|
979 |
+
"""
|
980 |
+
self.original_attn_processors = None
|
981 |
+
|
982 |
+
for _, attn_processor in self.attn_processors.items():
|
983 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
984 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
985 |
+
|
986 |
+
self.original_attn_processors = self.attn_processors
|
987 |
+
|
988 |
+
for module in self.modules():
|
989 |
+
if isinstance(module, Attention):
|
990 |
+
module.fuse_projections(fuse=True)
|
991 |
+
|
992 |
+
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
993 |
+
|
994 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
995 |
+
def unfuse_qkv_projections(self):
|
996 |
+
"""Disables the fused QKV projection if enabled.
|
997 |
+
|
998 |
+
<Tip warning={true}>
|
999 |
+
|
1000 |
+
This API is 🧪 experimental.
|
1001 |
+
|
1002 |
+
</Tip>
|
1003 |
+
|
1004 |
+
"""
|
1005 |
+
if self.original_attn_processors is not None:
|
1006 |
+
self.set_attn_processor(self.original_attn_processors)
|
1007 |
+
|
1008 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
1009 |
+
if hasattr(module, "gradient_checkpointing"):
|
1010 |
+
module.gradient_checkpointing = value
|
1011 |
+
|
1012 |
+
def forward(
|
1013 |
+
self,
|
1014 |
+
hidden_states: torch.Tensor,
|
1015 |
+
encoder_hidden_states: torch.Tensor = None,
|
1016 |
+
pooled_projections: torch.Tensor = None,
|
1017 |
+
timestep: torch.LongTensor = None,
|
1018 |
+
img_ids: torch.Tensor = None,
|
1019 |
+
txt_ids: torch.Tensor = None,
|
1020 |
+
guidance: torch.Tensor = None,
|
1021 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1022 |
+
controlnet_block_samples= None,
|
1023 |
+
controlnet_single_block_samples=None,
|
1024 |
+
return_dict: bool = True,
|
1025 |
+
micro_conds: torch.Tensor = None,
|
1026 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
1027 |
+
"""
|
1028 |
+
The [`FluxTransformer2DModel`] forward method.
|
1029 |
+
|
1030 |
+
Args:
|
1031 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
1032 |
+
Input `hidden_states`.
|
1033 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
1034 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
1035 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
1036 |
+
from the embeddings of input conditions.
|
1037 |
+
timestep ( `torch.LongTensor`):
|
1038 |
+
Used to indicate denoising step.
|
1039 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
1040 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
1041 |
+
joint_attention_kwargs (`dict`, *optional*):
|
1042 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1043 |
+
`self.processor` in
|
1044 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1045 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1046 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
1047 |
+
tuple.
|
1048 |
+
|
1049 |
+
Returns:
|
1050 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
1051 |
+
`tuple` where the first element is the sample tensor.
|
1052 |
+
"""
|
1053 |
+
micro_cond_encode_dim = 256 # same as self.config.micro_cond_encode_dim = 256 from amused
|
1054 |
+
micro_cond_embeds = get_timestep_embedding(
|
1055 |
+
micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
|
1056 |
+
)
|
1057 |
+
micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1))
|
1058 |
+
|
1059 |
+
pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1)
|
1060 |
+
pooled_projections = pooled_projections.to(dtype=self.dtype)
|
1061 |
+
pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype)
|
1062 |
+
|
1063 |
+
|
1064 |
+
hidden_states = self.embed(hidden_states)
|
1065 |
+
|
1066 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
1067 |
+
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
|
1068 |
+
hidden_states = self.down_block(hidden_states)
|
1069 |
+
|
1070 |
+
batch_size, channels, height, width = hidden_states.shape
|
1071 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
|
1072 |
+
hidden_states = self.project_to_hidden_norm(hidden_states)
|
1073 |
+
hidden_states = self.project_to_hidden(hidden_states)
|
1074 |
+
|
1075 |
+
|
1076 |
+
if joint_attention_kwargs is not None:
|
1077 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
1078 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
1079 |
+
else:
|
1080 |
+
lora_scale = 1.0
|
1081 |
+
|
1082 |
+
if USE_PEFT_BACKEND:
|
1083 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
1084 |
+
scale_lora_layers(self, lora_scale)
|
1085 |
+
else:
|
1086 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
1087 |
+
logger.warning(
|
1088 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
1092 |
+
if guidance is not None:
|
1093 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
1094 |
+
else:
|
1095 |
+
guidance = None
|
1096 |
+
temb = (
|
1097 |
+
self.time_text_embed(timestep, pooled_projections)
|
1098 |
+
if guidance is None
|
1099 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
1100 |
+
)
|
1101 |
+
|
1102 |
+
if txt_ids.ndim == 3:
|
1103 |
+
logger.warning(
|
1104 |
+
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
1105 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
1106 |
+
)
|
1107 |
+
txt_ids = txt_ids[0]
|
1108 |
+
if img_ids.ndim == 3:
|
1109 |
+
logger.warning(
|
1110 |
+
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
1111 |
+
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
1112 |
+
)
|
1113 |
+
img_ids = img_ids[0]
|
1114 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
1115 |
+
|
1116 |
+
image_rotary_emb = self.pos_embed(ids)
|
1117 |
+
|
1118 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
1119 |
+
if self.training and self.gradient_checkpointing:
|
1120 |
+
|
1121 |
+
def create_custom_forward(module, return_dict=None):
|
1122 |
+
def custom_forward(*inputs):
|
1123 |
+
if return_dict is not None:
|
1124 |
+
return module(*inputs, return_dict=return_dict)
|
1125 |
+
else:
|
1126 |
+
return module(*inputs)
|
1127 |
+
|
1128 |
+
return custom_forward
|
1129 |
+
|
1130 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1131 |
+
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
1132 |
+
create_custom_forward(block),
|
1133 |
+
hidden_states,
|
1134 |
+
encoder_hidden_states,
|
1135 |
+
temb,
|
1136 |
+
image_rotary_emb,
|
1137 |
+
**ckpt_kwargs,
|
1138 |
+
)
|
1139 |
+
|
1140 |
+
else:
|
1141 |
+
encoder_hidden_states, hidden_states = block(
|
1142 |
+
hidden_states=hidden_states,
|
1143 |
+
encoder_hidden_states=encoder_hidden_states,
|
1144 |
+
temb=temb,
|
1145 |
+
image_rotary_emb=image_rotary_emb,
|
1146 |
+
)
|
1147 |
+
|
1148 |
+
|
1149 |
+
# controlnet residual
|
1150 |
+
if controlnet_block_samples is not None:
|
1151 |
+
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
1152 |
+
interval_control = int(np.ceil(interval_control))
|
1153 |
+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
1154 |
+
|
1155 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
1156 |
+
|
1157 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
1158 |
+
if self.training and self.gradient_checkpointing:
|
1159 |
+
|
1160 |
+
def create_custom_forward(module, return_dict=None):
|
1161 |
+
def custom_forward(*inputs):
|
1162 |
+
if return_dict is not None:
|
1163 |
+
return module(*inputs, return_dict=return_dict)
|
1164 |
+
else:
|
1165 |
+
return module(*inputs)
|
1166 |
+
|
1167 |
+
return custom_forward
|
1168 |
+
|
1169 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
1170 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
1171 |
+
create_custom_forward(block),
|
1172 |
+
hidden_states,
|
1173 |
+
temb,
|
1174 |
+
image_rotary_emb,
|
1175 |
+
**ckpt_kwargs,
|
1176 |
+
)
|
1177 |
+
|
1178 |
+
else:
|
1179 |
+
hidden_states = block(
|
1180 |
+
hidden_states=hidden_states,
|
1181 |
+
temb=temb,
|
1182 |
+
image_rotary_emb=image_rotary_emb,
|
1183 |
+
)
|
1184 |
+
|
1185 |
+
# controlnet residual
|
1186 |
+
if controlnet_single_block_samples is not None:
|
1187 |
+
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
1188 |
+
interval_control = int(np.ceil(interval_control))
|
1189 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
1190 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
1191 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
1192 |
+
)
|
1193 |
+
|
1194 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
1195 |
+
|
1196 |
+
|
1197 |
+
hidden_states = self.project_from_hidden_norm(hidden_states)
|
1198 |
+
hidden_states = self.project_from_hidden(hidden_states)
|
1199 |
+
|
1200 |
+
|
1201 |
+
hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
|
1202 |
+
|
1203 |
+
hidden_states = self.up_block(hidden_states)
|
1204 |
+
|
1205 |
+
if USE_PEFT_BACKEND:
|
1206 |
+
# remove `lora_scale` from each PEFT layer
|
1207 |
+
unscale_lora_layers(self, lora_scale)
|
1208 |
+
|
1209 |
+
output = self.mlm_layer(hidden_states)
|
1210 |
+
# self.unfuse_qkv_projections()
|
1211 |
+
if not return_dict:
|
1212 |
+
return (output,)
|
1213 |
+
|
1214 |
+
|
1215 |
+
return output
|