sky24h commited on
Commit
1d24639
·
1 Parent(s): 3db7e62

init commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +160 -0
  2. LICENSE +201 -0
  3. app.py +50 -0
  4. checkpoints/stablemakeup/pytorch_model.bin +3 -0
  5. checkpoints/stablemakeup/pytorch_model_1.bin +3 -0
  6. checkpoints/stablemakeup/pytorch_model_2.bin +3 -0
  7. detail_encoder/.DS_Store +0 -0
  8. detail_encoder/__init__.py +0 -0
  9. detail_encoder/_clip.py +1349 -0
  10. detail_encoder/attention_processor.py +687 -0
  11. detail_encoder/encoder_plus.py +113 -0
  12. detail_encoder/resampler.py +112 -0
  13. diffusers/.DS_Store +0 -0
  14. diffusers/__init__.py +734 -0
  15. diffusers/commands/__init__.py +27 -0
  16. diffusers/commands/diffusers_cli.py +43 -0
  17. diffusers/commands/env.py +84 -0
  18. diffusers/commands/fp16_safetensors.py +133 -0
  19. diffusers/configuration_utils.py +694 -0
  20. diffusers/dependency_versions_check.py +35 -0
  21. diffusers/dependency_versions_table.py +46 -0
  22. diffusers/experimental/README.md +5 -0
  23. diffusers/experimental/__init__.py +1 -0
  24. diffusers/experimental/rl/__init__.py +1 -0
  25. diffusers/experimental/rl/value_guided_sampling.py +154 -0
  26. diffusers/image_processor.py +476 -0
  27. diffusers/loaders.py +0 -0
  28. diffusers/models/README.md +3 -0
  29. diffusers/models/__init__.py +77 -0
  30. diffusers/models/activations.py +120 -0
  31. diffusers/models/adapter.py +584 -0
  32. diffusers/models/attention.py +396 -0
  33. diffusers/models/attention_flax.py +486 -0
  34. diffusers/models/attention_processor.py +2020 -0
  35. diffusers/models/autoencoder_asym_kl.py +181 -0
  36. diffusers/models/autoencoder_kl.py +465 -0
  37. diffusers/models/autoencoder_tiny.py +349 -0
  38. diffusers/models/consistency_decoder_vae.py +430 -0
  39. diffusers/models/controlnet.py +844 -0
  40. diffusers/models/controlnet_flax.py +394 -0
  41. diffusers/models/dual_transformer_2d.py +155 -0
  42. diffusers/models/embeddings.py +792 -0
  43. diffusers/models/embeddings_flax.py +95 -0
  44. diffusers/models/lora.py +304 -0
  45. diffusers/models/modeling_flax_pytorch_utils.py +134 -0
  46. diffusers/models/modeling_flax_utils.py +560 -0
  47. diffusers/models/modeling_pytorch_flax_utils.py +161 -0
  48. diffusers/models/modeling_utils.py +1158 -0
  49. diffusers/models/normalization.py +148 -0
  50. diffusers/models/prior_transformer.py +382 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ from inference_utils import inference
4
+
5
+
6
+ @spaces.GPU
7
+ def send_to_model(id_image, makeup_image, guidance_scale):
8
+ if guidance_scale is None:
9
+ # when creating example caches.
10
+ guidance_scale = 1.6
11
+ return inference(id_image, makeup_image, guidance_scale, size=512)
12
+
13
+ if __name__ == "__main__":
14
+ with gr.Blocks() as demo:
15
+ gr.HTML(
16
+ """
17
+ <h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
18
+ Stable-Makeup: When Real-World Makeup Transfer Meets Diffusion Model
19
+ </h1>
20
+ <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
21
+ <a style="text-align: center; display:inline-block"
22
+ href="https://xiaojiu-z.github.io/Stable-Makeup.github.io/">
23
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
24
+ alt="Paper Page">
25
+ </a>
26
+ <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/Stable-Makeup-unofficial?duplicate=true">
27
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
28
+ </a>
29
+ </p>
30
+ """
31
+ )
32
+ gr.Interface(
33
+ fn=send_to_model,
34
+ inputs=[
35
+ gr.Image(type="pil", label="id_image", height=512, width=512),
36
+ gr.Image(type="pil", label="makeup_image", height=512, width=512),
37
+ gr.Slider(minimum=1.01, maximum=3, value=1.6, step=0.05, label="guidance_scale", info="1.05-1.15 is suggested for light makeup and 2 for heavy makeup."),
38
+ ],
39
+ outputs="image",
40
+ allow_flagging="never",
41
+ description="This is an unofficial demo for the paper 'Stable-Makeup: When Real-World Makeup Transfer Meets Diffusion Model'.",
42
+ examples=[
43
+ ["./test_imgs/id/1.jpg", "./test_imgs/makeup/1.jpg"],
44
+ ["./test_imgs/id/2.jpg", "./test_imgs/makeup/2.jpg"],
45
+ ["./test_imgs/id/3.jpg", "./test_imgs/makeup/3.jpg"],
46
+ ["./test_imgs/id/4.jpg", "./test_imgs/makeup/4.png"],
47
+ ],
48
+ cache_examples=True,
49
+ )
50
+ demo.queue(max_size=10).launch()
checkpoints/stablemakeup/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da8272fdbb74cb70714272e3b4ff381958944b4f308df16977dfe8893dfc7f64
3
+ size 1373905877
checkpoints/stablemakeup/pytorch_model_1.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05573314b4a16e592456e4c5dc3932fe705e1e35d8609ea886cb98ac2deadf47
3
+ size 1445256905
checkpoints/stablemakeup/pytorch_model_2.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:143d489e9e1f3c5014ae97f735b2c37bbc7c487b7a50f1f6062afc593ab9da40
3
+ size 1445256905
detail_encoder/.DS_Store ADDED
Binary file (6.15 kB). View file
 
detail_encoder/__init__.py ADDED
File without changes
detail_encoder/_clip.py ADDED
@@ -0,0 +1,1349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch CLIP model."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import (
29
+ ModelOutput,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
41
+
42
+ CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "openai/clip-vit-base-patch32",
44
+ # See all CLIP models at https://huggingface.co/models?filter=clip
45
+ ]
46
+
47
+
48
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
49
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
50
+ """
51
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
52
+ """
53
+ bsz, src_len = mask.size()
54
+ tgt_len = tgt_len if tgt_len is not None else src_len
55
+
56
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
57
+
58
+ inverted_mask = 1.0 - expanded_mask
59
+
60
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
61
+
62
+
63
+ # contrastive loss function, adapted from
64
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
65
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
66
+ return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
67
+
68
+
69
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
70
+ caption_loss = contrastive_loss(similarity)
71
+ image_loss = contrastive_loss(similarity.t())
72
+ return (caption_loss + image_loss) / 2.0
73
+
74
+
75
+ @dataclass
76
+ class CLIPVisionModelOutput(ModelOutput):
77
+ """
78
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
79
+
80
+ Args:
81
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
82
+ The image embeddings obtained by applying the projection layer to the pooler_output.
83
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
84
+ Sequence of hidden-states at the output of the last layer of the model.
85
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
86
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
87
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
88
+
89
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
90
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
91
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
92
+ sequence_length)`.
93
+
94
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
95
+ heads.
96
+ """
97
+
98
+ image_embeds: Optional[torch.FloatTensor] = None
99
+ last_hidden_state: torch.FloatTensor = None
100
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
101
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
102
+
103
+
104
+ @dataclass
105
+ class CLIPTextModelOutput(ModelOutput):
106
+ """
107
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
108
+
109
+ Args:
110
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
111
+ The text embeddings obtained by applying the projection layer to the pooler_output.
112
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
113
+ Sequence of hidden-states at the output of the last layer of the model.
114
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
115
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
116
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
117
+
118
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
119
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
120
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
121
+ sequence_length)`.
122
+
123
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
124
+ heads.
125
+ """
126
+
127
+ text_embeds: Optional[torch.FloatTensor] = None
128
+ last_hidden_state: torch.FloatTensor = None
129
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
130
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
131
+
132
+
133
+ @dataclass
134
+ class CLIPOutput(ModelOutput):
135
+ """
136
+ Args:
137
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
138
+ Contrastive loss for image-text similarity.
139
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
140
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
141
+ similarity scores.
142
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
143
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
144
+ similarity scores.
145
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
146
+ The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
147
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
148
+ The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
149
+ text_model_output(`BaseModelOutputWithPooling`):
150
+ The output of the [`CLIPTextModel`].
151
+ vision_model_output(`BaseModelOutputWithPooling`):
152
+ The output of the [`CLIPVisionModel`].
153
+ """
154
+
155
+ loss: Optional[torch.FloatTensor] = None
156
+ logits_per_image: torch.FloatTensor = None
157
+ logits_per_text: torch.FloatTensor = None
158
+ text_embeds: torch.FloatTensor = None
159
+ image_embeds: torch.FloatTensor = None
160
+ text_model_output: BaseModelOutputWithPooling = None
161
+ vision_model_output: BaseModelOutputWithPooling = None
162
+
163
+ def to_tuple(self) -> Tuple[Any]:
164
+ return tuple(
165
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
166
+ for k in self.keys()
167
+ )
168
+
169
+
170
+ class CLIPVisionEmbeddings(nn.Module):
171
+ def __init__(self, config: CLIPVisionConfig):
172
+ super().__init__()
173
+ self.config = config
174
+ self.embed_dim = config.hidden_size
175
+ self.image_size = config.image_size
176
+ self.patch_size = config.patch_size
177
+
178
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
179
+
180
+ self.patch_embedding = nn.Conv2d(
181
+ in_channels=config.num_channels,
182
+ out_channels=self.embed_dim,
183
+ kernel_size=self.patch_size,
184
+ stride=self.patch_size,
185
+ bias=False,
186
+ )
187
+
188
+ self.num_patches = (self.image_size // self.patch_size) ** 2
189
+ self.num_positions = self.num_patches + 1
190
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
191
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
192
+
193
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
194
+ batch_size = pixel_values.shape[0]
195
+ target_dtype = self.patch_embedding.weight.dtype
196
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
197
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
198
+
199
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
200
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
201
+ embeddings = embeddings + self.position_embedding(self.position_ids)
202
+ return embeddings
203
+
204
+
205
+ class CLIPTextEmbeddings(nn.Module):
206
+ def __init__(self, config: CLIPTextConfig):
207
+ super().__init__()
208
+ embed_dim = config.hidden_size
209
+
210
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
211
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
212
+
213
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
214
+ self.register_buffer(
215
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
216
+ )
217
+
218
+ def forward(
219
+ self,
220
+ input_ids: Optional[torch.LongTensor] = None,
221
+ position_ids: Optional[torch.LongTensor] = None,
222
+ inputs_embeds: Optional[torch.FloatTensor] = None,
223
+ ) -> torch.Tensor:
224
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
225
+
226
+ if position_ids is None:
227
+ position_ids = self.position_ids[:, :seq_length]
228
+
229
+ if inputs_embeds is None:
230
+ inputs_embeds = self.token_embedding(input_ids)
231
+
232
+ position_embeddings = self.position_embedding(position_ids)
233
+ embeddings = inputs_embeds + position_embeddings
234
+
235
+ return embeddings
236
+
237
+
238
+ class CLIPAttention(nn.Module):
239
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
240
+
241
+ def __init__(self, config):
242
+ super().__init__()
243
+ self.config = config
244
+ self.embed_dim = config.hidden_size
245
+ self.num_heads = config.num_attention_heads
246
+ self.head_dim = self.embed_dim // self.num_heads
247
+ if self.head_dim * self.num_heads != self.embed_dim:
248
+ raise ValueError(
249
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
250
+ f" {self.num_heads})."
251
+ )
252
+ self.scale = self.head_dim**-0.5
253
+ self.dropout = config.attention_dropout
254
+
255
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
256
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
257
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
258
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
259
+
260
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
261
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
262
+
263
+ def forward(
264
+ self,
265
+ hidden_states: torch.Tensor,
266
+ attention_mask: Optional[torch.Tensor] = None,
267
+ causal_attention_mask: Optional[torch.Tensor] = None,
268
+ output_attentions: Optional[bool] = False,
269
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
270
+ """Input shape: Batch x Time x Channel"""
271
+
272
+ bsz, tgt_len, embed_dim = hidden_states.size()
273
+
274
+ # get query proj
275
+ query_states = self.q_proj(hidden_states) * self.scale
276
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
277
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
278
+
279
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
280
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
281
+ key_states = key_states.view(*proj_shape)
282
+ value_states = value_states.view(*proj_shape)
283
+
284
+ src_len = key_states.size(1)
285
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
286
+
287
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
288
+ raise ValueError(
289
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
290
+ f" {attn_weights.size()}"
291
+ )
292
+
293
+ # apply the causal_attention_mask first
294
+ if causal_attention_mask is not None:
295
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
296
+ raise ValueError(
297
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
298
+ f" {causal_attention_mask.size()}"
299
+ )
300
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
301
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
302
+
303
+ if attention_mask is not None:
304
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
305
+ raise ValueError(
306
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
307
+ )
308
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
309
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
310
+
311
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
312
+
313
+ if output_attentions:
314
+ # this operation is a bit akward, but it's required to
315
+ # make sure that attn_weights keeps its gradient.
316
+ # In order to do so, attn_weights have to reshaped
317
+ # twice and have to be reused in the following
318
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
319
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
320
+ else:
321
+ attn_weights_reshaped = None
322
+
323
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
324
+
325
+ attn_output = torch.bmm(attn_probs, value_states)
326
+
327
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
328
+ raise ValueError(
329
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
330
+ f" {attn_output.size()}"
331
+ )
332
+
333
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
334
+ attn_output = attn_output.transpose(1, 2)
335
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
336
+
337
+ attn_output = self.out_proj(attn_output)
338
+
339
+ return attn_output, attn_weights_reshaped
340
+
341
+
342
+ class CLIPMLP(nn.Module):
343
+ def __init__(self, config):
344
+ super().__init__()
345
+ self.config = config
346
+ self.activation_fn = ACT2FN[config.hidden_act]
347
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
348
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
349
+
350
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
351
+ hidden_states = self.fc1(hidden_states)
352
+ hidden_states = self.activation_fn(hidden_states)
353
+ hidden_states = self.fc2(hidden_states)
354
+ return hidden_states
355
+
356
+
357
+ class CLIPEncoderLayer(nn.Module):
358
+ def __init__(self, config: CLIPConfig):
359
+ super().__init__()
360
+ self.embed_dim = config.hidden_size
361
+ self.self_attn = CLIPAttention(config)
362
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
363
+ self.mlp = CLIPMLP(config)
364
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
365
+
366
+ def forward(
367
+ self,
368
+ hidden_states: torch.Tensor,
369
+ attention_mask: torch.Tensor,
370
+ causal_attention_mask: torch.Tensor,
371
+ output_attentions: Optional[bool] = False,
372
+ ) -> Tuple[torch.FloatTensor]:
373
+ """
374
+ Args:
375
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
376
+ attention_mask (`torch.FloatTensor`): attention mask of size
377
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
378
+ `(config.encoder_attention_heads,)`.
379
+ output_attentions (`bool`, *optional*):
380
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
381
+ returned tensors for more detail.
382
+ """
383
+ residual = hidden_states
384
+
385
+ hidden_states = self.layer_norm1(hidden_states)
386
+ hidden_states, attn_weights = self.self_attn(
387
+ hidden_states=hidden_states,
388
+ attention_mask=attention_mask,
389
+ causal_attention_mask=causal_attention_mask,
390
+ output_attentions=output_attentions,
391
+ )
392
+ hidden_states = residual + hidden_states
393
+
394
+ residual = hidden_states
395
+ hidden_states = self.layer_norm2(hidden_states)
396
+ hidden_states = self.mlp(hidden_states)
397
+ hidden_states = residual + hidden_states
398
+
399
+ outputs = (hidden_states,)
400
+
401
+ if output_attentions:
402
+ outputs += (attn_weights,)
403
+
404
+ return outputs
405
+
406
+
407
+ class CLIPPreTrainedModel(PreTrainedModel):
408
+ """
409
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
410
+ models.
411
+ """
412
+
413
+ config_class = CLIPConfig
414
+ base_model_prefix = "clip"
415
+ supports_gradient_checkpointing = True
416
+
417
+ def _init_weights(self, module):
418
+ """Initialize the weights"""
419
+ factor = self.config.initializer_factor
420
+ if isinstance(module, CLIPTextEmbeddings):
421
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
422
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
423
+ elif isinstance(module, CLIPVisionEmbeddings):
424
+ factor = self.config.initializer_factor
425
+ nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
426
+ nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
427
+ nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
428
+ elif isinstance(module, CLIPAttention):
429
+ factor = self.config.initializer_factor
430
+ in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
431
+ out_proj_std = (module.embed_dim**-0.5) * factor
432
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
433
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
434
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
435
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
436
+ elif isinstance(module, CLIPMLP):
437
+ factor = self.config.initializer_factor
438
+ in_proj_std = (
439
+ (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
440
+ )
441
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
442
+ nn.init.normal_(module.fc1.weight, std=fc_std)
443
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
444
+ elif isinstance(module, CLIPModel):
445
+ nn.init.normal_(
446
+ module.text_projection.weight,
447
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
448
+ )
449
+ nn.init.normal_(
450
+ module.visual_projection.weight,
451
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
452
+ )
453
+ elif isinstance(module, CLIPVisionModelWithProjection):
454
+ nn.init.normal_(
455
+ module.visual_projection.weight,
456
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
457
+ )
458
+ elif isinstance(module, CLIPTextModelWithProjection):
459
+ nn.init.normal_(
460
+ module.text_projection.weight,
461
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
462
+ )
463
+
464
+ if isinstance(module, nn.LayerNorm):
465
+ module.bias.data.zero_()
466
+ module.weight.data.fill_(1.0)
467
+ if isinstance(module, nn.Linear) and module.bias is not None:
468
+ module.bias.data.zero_()
469
+
470
+ def _set_gradient_checkpointing(self, module, value=False):
471
+ if isinstance(module, CLIPEncoder):
472
+ module.gradient_checkpointing = value
473
+
474
+
475
+ CLIP_START_DOCSTRING = r"""
476
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
477
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
478
+ etc.)
479
+
480
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
481
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
482
+ and behavior.
483
+
484
+ Parameters:
485
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
486
+ Initializing with a config file does not load the weights associated with the model, only the
487
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
488
+ """
489
+
490
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
491
+ Args:
492
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
493
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
494
+ it.
495
+
496
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
497
+ [`PreTrainedTokenizer.__call__`] for details.
498
+
499
+ [What are input IDs?](../glossary#input-ids)
500
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
501
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
502
+
503
+ - 1 for tokens that are **not masked**,
504
+ - 0 for tokens that are **masked**.
505
+
506
+ [What are attention masks?](../glossary#attention-mask)
507
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
508
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
509
+ config.max_position_embeddings - 1]`.
510
+
511
+ [What are position IDs?](../glossary#position-ids)
512
+ output_attentions (`bool`, *optional*):
513
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
514
+ tensors for more detail.
515
+ output_hidden_states (`bool`, *optional*):
516
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
517
+ more detail.
518
+ return_dict (`bool`, *optional*):
519
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
520
+ """
521
+
522
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
523
+ Args:
524
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
525
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
526
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
527
+ output_attentions (`bool`, *optional*):
528
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
529
+ tensors for more detail.
530
+ output_hidden_states (`bool`, *optional*):
531
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
532
+ more detail.
533
+ return_dict (`bool`, *optional*):
534
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
535
+ """
536
+
537
+ CLIP_INPUTS_DOCSTRING = r"""
538
+ Args:
539
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
540
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
541
+ it.
542
+
543
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
544
+ [`PreTrainedTokenizer.__call__`] for details.
545
+
546
+ [What are input IDs?](../glossary#input-ids)
547
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
548
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
549
+
550
+ - 1 for tokens that are **not masked**,
551
+ - 0 for tokens that are **masked**.
552
+
553
+ [What are attention masks?](../glossary#attention-mask)
554
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
555
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
556
+ config.max_position_embeddings - 1]`.
557
+
558
+ [What are position IDs?](../glossary#position-ids)
559
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
560
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
561
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
562
+ return_loss (`bool`, *optional*):
563
+ Whether or not to return the contrastive loss.
564
+ output_attentions (`bool`, *optional*):
565
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
566
+ tensors for more detail.
567
+ output_hidden_states (`bool`, *optional*):
568
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
569
+ more detail.
570
+ return_dict (`bool`, *optional*):
571
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
572
+ """
573
+
574
+
575
+ class CLIPEncoder(nn.Module):
576
+ """
577
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
578
+ [`CLIPEncoderLayer`].
579
+
580
+ Args:
581
+ config: CLIPConfig
582
+ """
583
+
584
+ def __init__(self, config: CLIPConfig):
585
+ super().__init__()
586
+ self.config = config
587
+ self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
588
+ self.gradient_checkpointing = False
589
+
590
+ def forward(
591
+ self,
592
+ inputs_embeds,
593
+ attention_mask: Optional[torch.Tensor] = None,
594
+ causal_attention_mask: Optional[torch.Tensor] = None,
595
+ output_attentions: Optional[bool] = None,
596
+ output_hidden_states: Optional[bool] = None,
597
+ return_dict: Optional[bool] = None,
598
+ ) -> Union[Tuple, BaseModelOutput]:
599
+ r"""
600
+ Args:
601
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
602
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
603
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
604
+ than the model's internal embedding lookup matrix.
605
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
606
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
607
+
608
+ - 1 for tokens that are **not masked**,
609
+ - 0 for tokens that are **masked**.
610
+
611
+ [What are attention masks?](../glossary#attention-mask)
612
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
613
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
614
+
615
+ - 1 for tokens that are **not masked**,
616
+ - 0 for tokens that are **masked**.
617
+
618
+ [What are attention masks?](../glossary#attention-mask)
619
+ output_attentions (`bool`, *optional*):
620
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
621
+ returned tensors for more detail.
622
+ output_hidden_states (`bool`, *optional*):
623
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
624
+ for more detail.
625
+ return_dict (`bool`, *optional*):
626
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
627
+ """
628
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
629
+ output_hidden_states = (
630
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
631
+ )
632
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
633
+
634
+ encoder_states = () if output_hidden_states else None
635
+ all_attentions = () if output_attentions else None
636
+
637
+ hidden_states = inputs_embeds
638
+ for idx, encoder_layer in enumerate(self.layers):
639
+ if output_hidden_states:
640
+ encoder_states = encoder_states + (hidden_states,)
641
+ if self.gradient_checkpointing and self.training:
642
+
643
+ def create_custom_forward(module):
644
+ def custom_forward(*inputs):
645
+ return module(*inputs, output_attentions)
646
+
647
+ return custom_forward
648
+
649
+ layer_outputs = torch.utils.checkpoint.checkpoint(
650
+ create_custom_forward(encoder_layer),
651
+ hidden_states,
652
+ attention_mask,
653
+ causal_attention_mask,
654
+ )
655
+ else:
656
+ layer_outputs = encoder_layer(
657
+ hidden_states,
658
+ attention_mask,
659
+ causal_attention_mask,
660
+ output_attentions=output_attentions,
661
+ )
662
+
663
+ hidden_states = layer_outputs[0]
664
+
665
+ if output_attentions:
666
+ all_attentions = all_attentions + (layer_outputs[1],)
667
+
668
+ if output_hidden_states:
669
+ encoder_states = encoder_states + (hidden_states,)
670
+
671
+ if not return_dict:
672
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
673
+ return BaseModelOutput(
674
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
675
+ )
676
+
677
+
678
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
679
+ def _make_causal_mask(
680
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
681
+ ):
682
+ """
683
+ Make causal mask used for bi-directional self-attention.
684
+ """
685
+ bsz, tgt_len = input_ids_shape
686
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
687
+ mask_cond = torch.arange(mask.size(-1), device=device)
688
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
689
+ mask = mask.to(dtype)
690
+
691
+ if past_key_values_length > 0:
692
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
693
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
694
+
695
+
696
+ class CLIPTextTransformer(nn.Module):
697
+ def __init__(self, config: CLIPTextConfig):
698
+ super().__init__()
699
+ self.config = config
700
+ embed_dim = config.hidden_size
701
+ self.embeddings = CLIPTextEmbeddings(config)
702
+ self.encoder = CLIPEncoder(config)
703
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
704
+
705
+ # For `pooled_output` computation
706
+ self.eos_token_id = config.eos_token_id
707
+
708
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
709
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
710
+ def forward(
711
+ self,
712
+ input_ids: Optional[torch.Tensor] = None,
713
+ attention_mask: Optional[torch.Tensor] = None,
714
+ position_ids: Optional[torch.Tensor] = None,
715
+ output_attentions: Optional[bool] = None,
716
+ output_hidden_states: Optional[bool] = None,
717
+ return_dict: Optional[bool] = None,
718
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
719
+ r"""
720
+ Returns:
721
+
722
+ """
723
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
724
+ output_hidden_states = (
725
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
726
+ )
727
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
728
+
729
+ if input_ids is None:
730
+ raise ValueError("You have to specify input_ids")
731
+
732
+ input_shape = input_ids.size()
733
+ input_ids = input_ids.view(-1, input_shape[-1])
734
+
735
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
736
+
737
+ # CLIP's text model uses causal mask, prepare it here.
738
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
739
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
740
+ # expand attention_mask
741
+ if attention_mask is not None:
742
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
743
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
744
+
745
+ encoder_outputs = self.encoder(
746
+ inputs_embeds=hidden_states,
747
+ attention_mask=attention_mask,
748
+ causal_attention_mask=causal_attention_mask,
749
+ output_attentions=output_attentions,
750
+ output_hidden_states=output_hidden_states,
751
+ return_dict=return_dict,
752
+ )
753
+
754
+ last_hidden_state = encoder_outputs[0]
755
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
756
+
757
+ if self.eos_token_id == 2:
758
+ # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
759
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
760
+ # ------------------------------------------------------------
761
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
762
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
763
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
764
+ pooled_output = last_hidden_state[
765
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
766
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
767
+ ]
768
+ else:
769
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
770
+ pooled_output = last_hidden_state[
771
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
772
+ # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
773
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
774
+ .int()
775
+ .argmax(dim=-1),
776
+ ]
777
+
778
+ if not return_dict:
779
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
780
+
781
+ return BaseModelOutputWithPooling(
782
+ last_hidden_state=last_hidden_state,
783
+ pooler_output=pooled_output,
784
+ hidden_states=encoder_outputs.hidden_states,
785
+ attentions=encoder_outputs.attentions,
786
+ )
787
+
788
+
789
+ @add_start_docstrings(
790
+ """The text model from CLIP without any head or projection on top.""",
791
+ CLIP_START_DOCSTRING,
792
+ )
793
+ class CLIPTextModel(CLIPPreTrainedModel):
794
+ config_class = CLIPTextConfig
795
+
796
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
797
+
798
+ def __init__(self, config: CLIPTextConfig):
799
+ super().__init__(config)
800
+ self.text_model = CLIPTextTransformer(config)
801
+ # Initialize weights and apply final processing
802
+ self.post_init()
803
+
804
+ def get_input_embeddings(self) -> nn.Module:
805
+ return self.text_model.embeddings.token_embedding
806
+
807
+ def set_input_embeddings(self, value):
808
+ self.text_model.embeddings.token_embedding = value
809
+
810
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
811
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
812
+ def forward(
813
+ self,
814
+ input_ids: Optional[torch.Tensor] = None,
815
+ attention_mask: Optional[torch.Tensor] = None,
816
+ position_ids: Optional[torch.Tensor] = None,
817
+ output_attentions: Optional[bool] = None,
818
+ output_hidden_states: Optional[bool] = None,
819
+ return_dict: Optional[bool] = None,
820
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
821
+ r"""
822
+ Returns:
823
+
824
+ Examples:
825
+
826
+ ```python
827
+ >>> from transformers import AutoTokenizer, CLIPTextModel
828
+
829
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
830
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
831
+
832
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
833
+
834
+ >>> outputs = model(**inputs)
835
+ >>> last_hidden_state = outputs.last_hidden_state
836
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
837
+ ```"""
838
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
839
+
840
+ return self.text_model(
841
+ input_ids=input_ids,
842
+ attention_mask=attention_mask,
843
+ position_ids=position_ids,
844
+ output_attentions=output_attentions,
845
+ output_hidden_states=output_hidden_states,
846
+ return_dict=return_dict,
847
+ )
848
+
849
+
850
+ class CLIPVisionTransformer(nn.Module):
851
+ def __init__(self, config: CLIPVisionConfig):
852
+ super().__init__()
853
+ self.config = config
854
+ embed_dim = config.hidden_size
855
+
856
+ self.embeddings = CLIPVisionEmbeddings(config)
857
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
858
+ self.encoder = CLIPEncoder(config)
859
+ self.post_layernorm = None
860
+
861
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
862
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
863
+ def forward(
864
+ self,
865
+ pixel_values: Optional[torch.FloatTensor] = None,
866
+ output_attentions: Optional[bool] = None,
867
+ output_hidden_states: Optional[bool] = None,
868
+ return_dict: Optional[bool] = None,
869
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
870
+ r"""
871
+ Returns:
872
+
873
+ """
874
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
875
+ output_hidden_states = (
876
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
877
+ )
878
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
879
+
880
+ if pixel_values is None:
881
+ raise ValueError("You have to specify pixel_values")
882
+
883
+ hidden_states = self.embeddings(pixel_values)
884
+ hidden_states = self.pre_layrnorm(hidden_states)
885
+
886
+ encoder_outputs = self.encoder(
887
+ inputs_embeds=hidden_states,
888
+ output_attentions=output_attentions,
889
+ output_hidden_states=output_hidden_states,
890
+ return_dict=return_dict,
891
+ )
892
+
893
+ last_hidden_state = encoder_outputs[0]
894
+ # pooled_output = last_hidden_state[:, 0, :]
895
+ # pooled_output = self.post_layernorm(pooled_output)
896
+ pooled_output = None
897
+
898
+ if not return_dict:
899
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
900
+
901
+ return BaseModelOutputWithPooling(
902
+ last_hidden_state=last_hidden_state,
903
+ pooler_output=pooled_output,
904
+ hidden_states=encoder_outputs.hidden_states,
905
+ attentions=encoder_outputs.attentions,
906
+ )
907
+
908
+
909
+ @add_start_docstrings(
910
+ """The vision model from CLIP without any head or projection on top.""",
911
+ CLIP_START_DOCSTRING,
912
+ )
913
+ class CLIPVisionModel(CLIPPreTrainedModel):
914
+ config_class = CLIPVisionConfig
915
+ main_input_name = "pixel_values"
916
+
917
+ def __init__(self, config: CLIPVisionConfig):
918
+ super().__init__(config)
919
+ self.vision_model = CLIPVisionTransformer(config)
920
+ # Initialize weights and apply final processing
921
+ self.post_init()
922
+
923
+ def get_input_embeddings(self) -> nn.Module:
924
+ return self.vision_model.embeddings.patch_embedding
925
+
926
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
927
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
928
+ def forward(
929
+ self,
930
+ pixel_values: Optional[torch.FloatTensor] = None,
931
+ output_attentions: Optional[bool] = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
935
+ r"""
936
+ Returns:
937
+
938
+ Examples:
939
+
940
+ ```python
941
+ >>> from PIL import Image
942
+ >>> import requests
943
+ >>> from transformers import AutoProcessor, CLIPVisionModel
944
+
945
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
946
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
947
+
948
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
949
+ >>> image = Image.open(requests.get(url, stream=True).raw)
950
+
951
+ >>> inputs = processor(images=image, return_tensors="pt")
952
+
953
+ >>> outputs = model(**inputs)
954
+ >>> last_hidden_state = outputs.last_hidden_state
955
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
956
+ ```"""
957
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
958
+
959
+ return self.vision_model(
960
+ pixel_values=pixel_values,
961
+ output_attentions=output_attentions,
962
+ output_hidden_states=output_hidden_states,
963
+ return_dict=return_dict,
964
+ )
965
+
966
+
967
+ @add_start_docstrings(CLIP_START_DOCSTRING)
968
+ class CLIPModel(CLIPPreTrainedModel):
969
+ config_class = CLIPConfig
970
+
971
+ def __init__(self, config: CLIPConfig):
972
+ super().__init__(config)
973
+
974
+ if not isinstance(config.text_config, CLIPTextConfig):
975
+ raise ValueError(
976
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
977
+ f" {type(config.text_config)}."
978
+ )
979
+
980
+ if not isinstance(config.vision_config, CLIPVisionConfig):
981
+ raise ValueError(
982
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
983
+ f" {type(config.vision_config)}."
984
+ )
985
+
986
+ text_config = config.text_config
987
+ vision_config = config.vision_config
988
+
989
+ self.projection_dim = config.projection_dim
990
+ self.text_embed_dim = text_config.hidden_size
991
+ self.vision_embed_dim = vision_config.hidden_size
992
+
993
+ self.text_model = CLIPTextTransformer(text_config)
994
+ self.vision_model = CLIPVisionTransformer(vision_config)
995
+
996
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
997
+ self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
998
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
999
+
1000
+ # Initialize weights and apply final processing
1001
+ self.post_init()
1002
+
1003
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1004
+ def get_text_features(
1005
+ self,
1006
+ input_ids: Optional[torch.Tensor] = None,
1007
+ attention_mask: Optional[torch.Tensor] = None,
1008
+ position_ids: Optional[torch.Tensor] = None,
1009
+ output_attentions: Optional[bool] = None,
1010
+ output_hidden_states: Optional[bool] = None,
1011
+ return_dict: Optional[bool] = None,
1012
+ ) -> torch.FloatTensor:
1013
+ r"""
1014
+ Returns:
1015
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
1016
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
1017
+
1018
+ Examples:
1019
+
1020
+ ```python
1021
+ >>> from transformers import AutoTokenizer, CLIPModel
1022
+
1023
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1024
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1025
+
1026
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1027
+ >>> text_features = model.get_text_features(**inputs)
1028
+ ```"""
1029
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1030
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1031
+ output_hidden_states = (
1032
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1033
+ )
1034
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1035
+
1036
+ text_outputs = self.text_model(
1037
+ input_ids=input_ids,
1038
+ attention_mask=attention_mask,
1039
+ position_ids=position_ids,
1040
+ output_attentions=output_attentions,
1041
+ output_hidden_states=output_hidden_states,
1042
+ return_dict=return_dict,
1043
+ )
1044
+
1045
+ pooled_output = text_outputs[1]
1046
+ text_features = self.text_projection(pooled_output)
1047
+
1048
+ return text_features
1049
+
1050
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1051
+ def get_image_features(
1052
+ self,
1053
+ pixel_values: Optional[torch.FloatTensor] = None,
1054
+ output_attentions: Optional[bool] = None,
1055
+ output_hidden_states: Optional[bool] = None,
1056
+ return_dict: Optional[bool] = None,
1057
+ ) -> torch.FloatTensor:
1058
+ r"""
1059
+ Returns:
1060
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
1061
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
1062
+
1063
+ Examples:
1064
+
1065
+ ```python
1066
+ >>> from PIL import Image
1067
+ >>> import requests
1068
+ >>> from transformers import AutoProcessor, CLIPModel
1069
+
1070
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1071
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1072
+
1073
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1074
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1075
+
1076
+ >>> inputs = processor(images=image, return_tensors="pt")
1077
+
1078
+ >>> image_features = model.get_image_features(**inputs)
1079
+ ```"""
1080
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1081
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1082
+ output_hidden_states = (
1083
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1084
+ )
1085
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1086
+
1087
+ vision_outputs = self.vision_model(
1088
+ pixel_values=pixel_values,
1089
+ output_attentions=output_attentions,
1090
+ output_hidden_states=output_hidden_states,
1091
+ return_dict=return_dict,
1092
+ )
1093
+
1094
+ pooled_output = vision_outputs[1] # pooled_output
1095
+ image_features = self.visual_projection(pooled_output)
1096
+
1097
+ return image_features
1098
+
1099
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1100
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
1101
+ def forward(
1102
+ self,
1103
+ input_ids: Optional[torch.LongTensor] = None,
1104
+ pixel_values: Optional[torch.FloatTensor] = None,
1105
+ attention_mask: Optional[torch.Tensor] = None,
1106
+ position_ids: Optional[torch.LongTensor] = None,
1107
+ return_loss: Optional[bool] = None,
1108
+ output_attentions: Optional[bool] = None,
1109
+ output_hidden_states: Optional[bool] = None,
1110
+ return_dict: Optional[bool] = None,
1111
+ ) -> Union[Tuple, CLIPOutput]:
1112
+ r"""
1113
+ Returns:
1114
+
1115
+ Examples:
1116
+
1117
+ ```python
1118
+ >>> from PIL import Image
1119
+ >>> import requests
1120
+ >>> from transformers import AutoProcessor, CLIPModel
1121
+
1122
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1123
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1124
+
1125
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1126
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1127
+
1128
+ >>> inputs = processor(
1129
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1130
+ ... )
1131
+
1132
+ >>> outputs = model(**inputs)
1133
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1134
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1135
+ ```"""
1136
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1137
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1138
+ output_hidden_states = (
1139
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1140
+ )
1141
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1142
+
1143
+ vision_outputs = self.vision_model(
1144
+ pixel_values=pixel_values,
1145
+ output_attentions=output_attentions,
1146
+ output_hidden_states=output_hidden_states,
1147
+ return_dict=return_dict,
1148
+ )
1149
+
1150
+ text_outputs = self.text_model(
1151
+ input_ids=input_ids,
1152
+ attention_mask=attention_mask,
1153
+ position_ids=position_ids,
1154
+ output_attentions=output_attentions,
1155
+ output_hidden_states=output_hidden_states,
1156
+ return_dict=return_dict,
1157
+ )
1158
+
1159
+ image_embeds = vision_outputs[1]
1160
+ image_embeds = self.visual_projection(image_embeds)
1161
+
1162
+ text_embeds = text_outputs[1]
1163
+ text_embeds = self.text_projection(text_embeds)
1164
+
1165
+ # normalized features
1166
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1167
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1168
+
1169
+ # cosine similarity as logits
1170
+ logit_scale = self.logit_scale.exp()
1171
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1172
+ logits_per_image = logits_per_text.t()
1173
+
1174
+ loss = None
1175
+ if return_loss:
1176
+ loss = clip_loss(logits_per_text)
1177
+
1178
+ if not return_dict:
1179
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1180
+ return ((loss,) + output) if loss is not None else output
1181
+
1182
+ return CLIPOutput(
1183
+ loss=loss,
1184
+ logits_per_image=logits_per_image,
1185
+ logits_per_text=logits_per_text,
1186
+ text_embeds=text_embeds,
1187
+ image_embeds=image_embeds,
1188
+ text_model_output=text_outputs,
1189
+ vision_model_output=vision_outputs,
1190
+ )
1191
+
1192
+
1193
+ @add_start_docstrings(
1194
+ """
1195
+ CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
1196
+ """,
1197
+ CLIP_START_DOCSTRING,
1198
+ )
1199
+ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
1200
+ config_class = CLIPTextConfig
1201
+
1202
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
1203
+
1204
+ def __init__(self, config: CLIPTextConfig):
1205
+ super().__init__(config)
1206
+
1207
+ self.text_model = CLIPTextTransformer(config)
1208
+
1209
+ self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1210
+
1211
+ # Initialize weights and apply final processing
1212
+ self.post_init()
1213
+
1214
+ def get_input_embeddings(self) -> nn.Module:
1215
+ return self.text_model.embeddings.token_embedding
1216
+
1217
+ def set_input_embeddings(self, value):
1218
+ self.text_model.embeddings.token_embedding = value
1219
+
1220
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1221
+ @replace_return_docstrings(output_type=CLIPTextModelOutput, config_class=CLIPTextConfig)
1222
+ def forward(
1223
+ self,
1224
+ input_ids: Optional[torch.Tensor] = None,
1225
+ attention_mask: Optional[torch.Tensor] = None,
1226
+ position_ids: Optional[torch.Tensor] = None,
1227
+ output_attentions: Optional[bool] = None,
1228
+ output_hidden_states: Optional[bool] = None,
1229
+ return_dict: Optional[bool] = None,
1230
+ ) -> Union[Tuple, CLIPTextModelOutput]:
1231
+ r"""
1232
+ Returns:
1233
+
1234
+ Examples:
1235
+
1236
+ ```python
1237
+ >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
1238
+
1239
+ >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1240
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1241
+
1242
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1243
+
1244
+ >>> outputs = model(**inputs)
1245
+ >>> text_embeds = outputs.text_embeds
1246
+ ```"""
1247
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1248
+
1249
+ text_outputs = self.text_model(
1250
+ input_ids=input_ids,
1251
+ attention_mask=attention_mask,
1252
+ position_ids=position_ids,
1253
+ output_attentions=output_attentions,
1254
+ output_hidden_states=output_hidden_states,
1255
+ return_dict=return_dict,
1256
+ )
1257
+
1258
+ pooled_output = text_outputs[1]
1259
+
1260
+ text_embeds = self.text_projection(pooled_output)
1261
+
1262
+ if not return_dict:
1263
+ outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
1264
+ return tuple(output for output in outputs if output is not None)
1265
+
1266
+ return CLIPTextModelOutput(
1267
+ text_embeds=text_embeds,
1268
+ last_hidden_state=text_outputs.last_hidden_state,
1269
+ hidden_states=text_outputs.hidden_states,
1270
+ attentions=text_outputs.attentions,
1271
+ )
1272
+
1273
+
1274
+ @add_start_docstrings(
1275
+ """
1276
+ CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
1277
+ """,
1278
+ CLIP_START_DOCSTRING,
1279
+ )
1280
+ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
1281
+ config_class = CLIPVisionConfig
1282
+ main_input_name = "pixel_values"
1283
+
1284
+ def __init__(self, config: CLIPVisionConfig):
1285
+ super().__init__(config)
1286
+
1287
+ self.vision_model = CLIPVisionTransformer(config)
1288
+
1289
+ self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1290
+
1291
+ # Initialize weights and apply final processing
1292
+ self.post_init()
1293
+
1294
+ def get_input_embeddings(self) -> nn.Module:
1295
+ return self.vision_model.embeddings.patch_embedding
1296
+
1297
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1298
+ @replace_return_docstrings(output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig)
1299
+ def forward(
1300
+ self,
1301
+ pixel_values: Optional[torch.FloatTensor] = None,
1302
+ output_attentions: Optional[bool] = None,
1303
+ output_hidden_states: Optional[bool] = None,
1304
+ return_dict: Optional[bool] = None,
1305
+ ) -> Union[Tuple, CLIPVisionModelOutput]:
1306
+ r"""
1307
+ Returns:
1308
+
1309
+ Examples:
1310
+
1311
+ ```python
1312
+ >>> from PIL import Image
1313
+ >>> import requests
1314
+ >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
1315
+
1316
+ >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1317
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1318
+
1319
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1320
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1321
+
1322
+ >>> inputs = processor(images=image, return_tensors="pt")
1323
+
1324
+ >>> outputs = model(**inputs)
1325
+ >>> image_embeds = outputs.image_embeds
1326
+ ```"""
1327
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1328
+
1329
+ vision_outputs = self.vision_model(
1330
+ pixel_values=pixel_values,
1331
+ output_attentions=output_attentions,
1332
+ output_hidden_states=output_hidden_states,
1333
+ return_dict=return_dict,
1334
+ )
1335
+
1336
+ pooled_output = vision_outputs[1] # pooled_output
1337
+
1338
+ image_embeds = self.visual_projection(pooled_output)
1339
+
1340
+ if not return_dict:
1341
+ outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
1342
+ return tuple(output for output in outputs if output is not None)
1343
+
1344
+ return CLIPVisionModelOutput(
1345
+ image_embeds=image_embeds,
1346
+ last_hidden_state=vision_outputs.last_hidden_state,
1347
+ hidden_states=vision_outputs.hidden_states,
1348
+ attentions=vision_outputs.attentions,
1349
+ )
detail_encoder/attention_processor.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from diffusers.utils.import_utils import is_xformers_available
6
+ from torchvision import transforms
7
+ if is_xformers_available():
8
+ import xformers
9
+ import xformers.ops
10
+ else:
11
+ xformers = None
12
+
13
+ class SSRAttnProcessor(nn.Module):
14
+ r"""
15
+ Attention processor for SSR-Adapater.
16
+ """
17
+
18
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1):
19
+ super().__init__()
20
+ self.hidden_size = hidden_size
21
+ self.cross_attention_dim = cross_attention_dim
22
+ self.scale = scale
23
+ # self.to_q_SSR = nn.Linear(hidden_size, hidden_size, bias=False)
24
+ self.to_k_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
25
+ self.to_v_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
26
+
27
+ def __call__(
28
+ self,
29
+ attn,
30
+ hidden_states,
31
+ encoder_hidden_states=None,
32
+ attention_mask=None,
33
+ temb=None,
34
+ ):
35
+ residual = hidden_states
36
+
37
+ if attn.spatial_norm is not None:
38
+ hidden_states = attn.spatial_norm(hidden_states, temb)
39
+
40
+ input_ndim = hidden_states.ndim
41
+
42
+ if input_ndim == 4:
43
+ batch_size, channel, height, width = hidden_states.shape
44
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
45
+
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
48
+ )
49
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
50
+
51
+ if attn.group_norm is not None:
52
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
53
+
54
+ # query = self.to_q_SSR(hidden_states)
55
+ query = attn.to_q(hidden_states)
56
+ query = attn.head_to_batch_dim(query)
57
+
58
+ if encoder_hidden_states is None:
59
+ encoder_hidden_states = hidden_states
60
+ elif attn.norm_cross:
61
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
62
+
63
+ _hidden_states = encoder_hidden_states
64
+ _key = self.to_k_SSR(_hidden_states)
65
+ _value = self.to_v_SSR(_hidden_states)
66
+ _key = attn.head_to_batch_dim(_key)
67
+ _value = attn.head_to_batch_dim(_value)
68
+ _attention_probs = attn.get_attention_scores(query, _key, None)
69
+ _hidden_states = torch.bmm(_attention_probs, _value)
70
+ _hidden_states = attn.batch_to_head_dim(_hidden_states)
71
+ hidden_states = self.scale * _hidden_states
72
+
73
+ # linear proj
74
+ hidden_states = attn.to_out[0](hidden_states)
75
+ # dropout
76
+ hidden_states = attn.to_out[1](hidden_states)
77
+
78
+ if input_ndim == 4:
79
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
80
+
81
+ if attn.residual_connection:
82
+ hidden_states = hidden_states + residual
83
+
84
+ hidden_states = hidden_states / attn.rescale_output_factor
85
+
86
+ return hidden_states
87
+
88
+
89
+ class SSRAttnProcessor2_0(torch.nn.Module):
90
+ r"""
91
+ Attention processor for SSR-Adapater for PyTorch 2.0.
92
+ """
93
+
94
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
95
+ super().__init__()
96
+
97
+ if not hasattr(F, "scaled_dot_product_attention"):
98
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
99
+ self.hidden_size = hidden_size
100
+ self.cross_attention_dim = cross_attention_dim
101
+ self.scale = scale
102
+ # self.to_q_SSR = nn.Linear(hidden_size, hidden_size, bias=False)
103
+ self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
+ self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
105
+
106
+ def __call__(
107
+ self,
108
+ attn,
109
+ hidden_states,
110
+ encoder_hidden_states=None,
111
+ attention_mask=None,
112
+ temb=None,
113
+ ):
114
+ residual = hidden_states
115
+
116
+ if attn.spatial_norm is not None:
117
+ hidden_states = attn.spatial_norm(hidden_states, temb)
118
+
119
+ input_ndim = hidden_states.ndim
120
+
121
+ if input_ndim == 4:
122
+ batch_size, channel, height, width = hidden_states.shape
123
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
124
+
125
+ batch_size, sequence_length, _ = (
126
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
127
+ )
128
+
129
+ if attention_mask is not None:
130
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
131
+ # scaled_dot_product_attention expects attention_mask shape to be
132
+ # (batch, heads, source_length, target_length)
133
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
134
+
135
+ if attn.group_norm is not None:
136
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
137
+
138
+ # query = self.to_q_SSR(hidden_states)
139
+ query = attn.to_q(hidden_states)
140
+
141
+ if encoder_hidden_states is None:
142
+ encoder_hidden_states = hidden_states
143
+ elif attn.norm_cross:
144
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
145
+
146
+ # split hidden states
147
+ _hidden_states = encoder_hidden_states
148
+
149
+ _key = self.to_k_SSR(_hidden_states)
150
+ _value = self.to_v_SSR(_hidden_states)
151
+ inner_dim = _key.shape[-1]
152
+ head_dim = inner_dim // attn.heads
153
+
154
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
155
+
156
+ _key = _key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
157
+ _value = _value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
158
+
159
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
160
+ # TODO: add support for attn.scale when we move to Torch 2.1
161
+ _hidden_states = F.scaled_dot_product_attention(
162
+ query, _key, _value, attn_mask=None, dropout_p=0.0, is_causal=False
163
+ )
164
+
165
+ _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
166
+ _hidden_states = _hidden_states.to(query.dtype)
167
+
168
+ hidden_states = self.scale * _hidden_states
169
+
170
+ # linear proj
171
+ hidden_states = attn.to_out[0](hidden_states)
172
+ # dropout
173
+ hidden_states = attn.to_out[1](hidden_states)
174
+
175
+ if input_ndim == 4:
176
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
177
+
178
+ if attn.residual_connection:
179
+ hidden_states = hidden_states + residual
180
+
181
+ hidden_states = hidden_states / attn.rescale_output_factor
182
+
183
+ return hidden_states
184
+
185
+
186
+ class AttnProcessor2_0(torch.nn.Module):
187
+ r"""
188
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ hidden_size=None,
194
+ cross_attention_dim=None,
195
+ ):
196
+ super().__init__()
197
+ if not hasattr(F, "scaled_dot_product_attention"):
198
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
199
+
200
+ def __call__(
201
+ self,
202
+ attn,
203
+ hidden_states,
204
+ encoder_hidden_states=None,
205
+ attention_mask=None,
206
+ temb=None,
207
+ ):
208
+ residual = hidden_states
209
+
210
+ if attn.spatial_norm is not None:
211
+ hidden_states = attn.spatial_norm(hidden_states, temb)
212
+
213
+ input_ndim = hidden_states.ndim
214
+
215
+ if input_ndim == 4:
216
+ batch_size, channel, height, width = hidden_states.shape
217
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
218
+
219
+ batch_size, sequence_length, _ = (
220
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
221
+ )
222
+
223
+ if attention_mask is not None:
224
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
225
+ # scaled_dot_product_attention expects attention_mask shape to be
226
+ # (batch, heads, source_length, target_length)
227
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
228
+
229
+ if attn.group_norm is not None:
230
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
231
+
232
+ query = attn.to_q(hidden_states)
233
+
234
+ if encoder_hidden_states is None:
235
+ encoder_hidden_states = hidden_states
236
+ elif attn.norm_cross:
237
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
238
+
239
+ key = attn.to_k(encoder_hidden_states)
240
+ value = attn.to_v(encoder_hidden_states)
241
+
242
+ inner_dim = key.shape[-1]
243
+ head_dim = inner_dim // attn.heads
244
+
245
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
246
+
247
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
248
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+
250
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
251
+ # TODO: add support for attn.scale when we move to Torch 2.1
252
+ hidden_states = F.scaled_dot_product_attention(
253
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
254
+ )
255
+
256
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
257
+ hidden_states = hidden_states.to(query.dtype)
258
+
259
+ # linear proj
260
+ hidden_states = attn.to_out[0](hidden_states)
261
+ # dropout
262
+ hidden_states = attn.to_out[1](hidden_states)
263
+
264
+ if input_ndim == 4:
265
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
266
+
267
+ if attn.residual_connection:
268
+ hidden_states = hidden_states + residual
269
+
270
+ hidden_states = hidden_states / attn.rescale_output_factor
271
+
272
+ return hidden_states
273
+
274
+ class AttnProcessor(nn.Module):
275
+ r"""
276
+ Default processor for performing attention-related computations.
277
+ """
278
+ def __init__(
279
+ self,
280
+ hidden_size=None,
281
+ cross_attention_dim=None,
282
+ ):
283
+ super().__init__()
284
+
285
+ def __call__(
286
+ self,
287
+ attn,
288
+ hidden_states,
289
+ encoder_hidden_states=None,
290
+ attention_mask=None,
291
+ temb=None,
292
+ ):
293
+ residual = hidden_states
294
+
295
+ if attn.spatial_norm is not None:
296
+ hidden_states = attn.spatial_norm(hidden_states, temb)
297
+
298
+ input_ndim = hidden_states.ndim
299
+
300
+ if input_ndim == 4:
301
+ batch_size, channel, height, width = hidden_states.shape
302
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
303
+
304
+ batch_size, sequence_length, _ = (
305
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
306
+ )
307
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
308
+
309
+ if attn.group_norm is not None:
310
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
311
+
312
+ query = attn.to_q(hidden_states)
313
+
314
+ if encoder_hidden_states is None:
315
+ encoder_hidden_states = hidden_states
316
+ elif attn.norm_cross:
317
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
318
+
319
+ key = attn.to_k(encoder_hidden_states)
320
+ value = attn.to_v(encoder_hidden_states)
321
+
322
+ query = attn.head_to_batch_dim(query)
323
+ key = attn.head_to_batch_dim(key)
324
+ value = attn.head_to_batch_dim(value)
325
+
326
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
327
+ hidden_states = torch.bmm(attention_probs, value)
328
+ hidden_states = attn.batch_to_head_dim(hidden_states)
329
+
330
+ # linear proj
331
+ hidden_states = attn.to_out[0](hidden_states)
332
+ # dropout
333
+ hidden_states = attn.to_out[1](hidden_states)
334
+
335
+ if input_ndim == 4:
336
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
337
+
338
+ if attn.residual_connection:
339
+ hidden_states = hidden_states + residual
340
+
341
+ hidden_states = hidden_states / attn.rescale_output_factor
342
+
343
+ return hidden_states
344
+
345
+
346
+ class ConvAttnProcessor:
347
+ def __call__(
348
+ self,
349
+ attn,
350
+ hidden_states,
351
+ encoder_hidden_states=None,
352
+ attention_mask=None,
353
+ ):
354
+ ## map to 2D
355
+ if len(hidden_states.shape) == 4:
356
+ shape = hidden_states.shape
357
+ hidden_states = torch.reshape(hidden_states, (shape[0], shape[1], shape[2] * shape[3]))
358
+ hidden_states = hidden_states.permute(0, 2, 1)
359
+ if encoder_hidden_states is not None:
360
+ if len(encoder_hidden_states.shape) == 4:
361
+ kv_shape = encoder_hidden_states.shape
362
+ encoder_hidden_states = torch.reshape(
363
+ encoder_hidden_states, (kv_shape[0], kv_shape[1], kv_shape[2] * kv_shape[3])
364
+ )
365
+ encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
366
+
367
+ # the same to standard attn
368
+ batch_size, sequence_length, _ = (
369
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
370
+ )
371
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
372
+ query = attn.to_q(hidden_states)
373
+
374
+ if encoder_hidden_states is None:
375
+ encoder_hidden_states = hidden_states
376
+ elif attn.norm_cross:
377
+ encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
378
+
379
+ key = attn.to_k(encoder_hidden_states)
380
+ value = attn.to_v(encoder_hidden_states)
381
+
382
+ query = attn.head_to_batch_dim(query)
383
+ key = attn.head_to_batch_dim(key)
384
+ value = attn.head_to_batch_dim(value)
385
+
386
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
387
+ hidden_states = torch.bmm(attention_probs, value)
388
+ hidden_states = attn.batch_to_head_dim(hidden_states)
389
+
390
+ # linear proj
391
+ hidden_states = attn.to_out[0](hidden_states)
392
+ # dropout
393
+ hidden_states = attn.to_out[1](hidden_states)
394
+
395
+ # map back to 4D
396
+ if len(hidden_states.shape) == 3:
397
+ hidden_states = hidden_states.permute(0, 2, 1)
398
+ hidden_states = torch.reshape(hidden_states, (shape[0], shape[1], shape[2], shape[3]))
399
+
400
+ return hidden_states
401
+
402
+
403
+ class SSRAttnProcessor_text(nn.Module):
404
+ r"""
405
+ Attention processor for SSR-Adapater.
406
+ """
407
+
408
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1):
409
+ super().__init__()
410
+ self.text_context_len = 77
411
+ self.hidden_size = hidden_size
412
+ self.cross_attention_dim = cross_attention_dim
413
+ self.scale = scale
414
+ self.to_k_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
415
+ self.to_v_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
416
+
417
+ def __call__(
418
+ self,
419
+ attn,
420
+ hidden_states,
421
+ encoder_hidden_states=None,
422
+ attention_mask=None,
423
+ temb=None,
424
+ ):
425
+ residual = hidden_states
426
+
427
+ if attn.spatial_norm is not None:
428
+ hidden_states = attn.spatial_norm(hidden_states, temb)
429
+
430
+ input_ndim = hidden_states.ndim
431
+
432
+ if input_ndim == 4:
433
+ batch_size, channel, height, width = hidden_states.shape
434
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
435
+
436
+ batch_size, sequence_length, _ = (
437
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
438
+ )
439
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
440
+
441
+ if attn.group_norm is not None:
442
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
443
+
444
+ query = attn.to_q(hidden_states)
445
+ query = attn.head_to_batch_dim(query)
446
+
447
+ if encoder_hidden_states is None:
448
+ encoder_hidden_states = hidden_states
449
+ elif attn.norm_cross:
450
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
451
+
452
+ # split hidden states
453
+ encoder_hidden_states, _hidden_states = encoder_hidden_states[:, :self.text_context_len,
454
+ :], encoder_hidden_states[:, self.text_context_len:, :]
455
+ encoder_hidden_states = encoder_hidden_states[:, :, :768]
456
+ # for text
457
+ key = attn.to_k(encoder_hidden_states)
458
+ value = attn.to_v(encoder_hidden_states)
459
+
460
+ key = attn.head_to_batch_dim(key)
461
+ value = attn.head_to_batch_dim(value)
462
+
463
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
464
+ hidden_states = torch.bmm(attention_probs, value)
465
+ hidden_states = attn.batch_to_head_dim(hidden_states)
466
+
467
+ # for image
468
+ _key = self.to_k_SSR(_hidden_states)
469
+ _value = self.to_v_SSR(_hidden_states)
470
+ _key = attn.head_to_batch_dim(_key)
471
+ _value = attn.head_to_batch_dim(_value)
472
+ _attention_probs = attn.get_attention_scores(query, _key, None)
473
+ _hidden_states = torch.bmm(_attention_probs, _value)
474
+ _hidden_states = attn.batch_to_head_dim(_hidden_states)
475
+ hidden_states = self.scale * _hidden_states + hidden_states
476
+
477
+ # linear proj
478
+ hidden_states = attn.to_out[0](hidden_states)
479
+ # dropout
480
+ hidden_states = attn.to_out[1](hidden_states)
481
+
482
+ if input_ndim == 4:
483
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
484
+
485
+ if attn.residual_connection:
486
+ hidden_states = hidden_states + residual
487
+
488
+ hidden_states = hidden_states / attn.rescale_output_factor
489
+
490
+ return hidden_states
491
+
492
+
493
+ class SSRAttnProcessor2_0_text(torch.nn.Module):
494
+ r"""
495
+ Attention processor for SSR-Adapater for PyTorch 2.0.
496
+ """
497
+
498
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
499
+ super().__init__()
500
+
501
+ if not hasattr(F, "scaled_dot_product_attention"):
502
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
503
+ self.text_context_len = 77
504
+ self.hidden_size = hidden_size
505
+ self.cross_attention_dim = cross_attention_dim
506
+ self.scale = scale
507
+ self.to_k_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
508
+ self.to_v_SSR = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
509
+
510
+ def __call__(
511
+ self,
512
+ attn,
513
+ hidden_states,
514
+ encoder_hidden_states=None,
515
+ attention_mask=None,
516
+ temb=None,
517
+ ):
518
+ residual = hidden_states
519
+
520
+ if attn.spatial_norm is not None:
521
+ hidden_states = attn.spatial_norm(hidden_states, temb)
522
+
523
+ input_ndim = hidden_states.ndim
524
+
525
+ if input_ndim == 4:
526
+ batch_size, channel, height, width = hidden_states.shape
527
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
528
+
529
+ batch_size, sequence_length, _ = (
530
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
531
+ )
532
+
533
+ if attention_mask is not None:
534
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
535
+ # scaled_dot_product_attention expects attention_mask shape to be
536
+ # (batch, heads, source_length, target_length)
537
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
538
+
539
+ if attn.group_norm is not None:
540
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
541
+
542
+ query = attn.to_q(hidden_states)
543
+
544
+ if encoder_hidden_states is None:
545
+ encoder_hidden_states = hidden_states
546
+ elif attn.norm_cross:
547
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
548
+
549
+ # split hidden states
550
+ encoder_hidden_states, _hidden_states = encoder_hidden_states[:, :self.text_context_len,
551
+ :], encoder_hidden_states[:, self.text_context_len:, :]
552
+
553
+ encoder_hidden_states = encoder_hidden_states[:, :, :768]
554
+ # for text
555
+ key = attn.to_k(encoder_hidden_states)
556
+ value = attn.to_v(encoder_hidden_states)
557
+ inner_dim = key.shape[-1]
558
+ head_dim = inner_dim // attn.heads
559
+
560
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
561
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
562
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
563
+
564
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
565
+ # TODO: add support for attn.scale when we move to Torch 2.1
566
+ hidden_states = F.scaled_dot_product_attention(
567
+ query, key, value, attn_mask=attention_mask, dropout_p = 0.0, is_causal = False
568
+ )
569
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
570
+ hidden_states = hidden_states.to(query.dtype)
571
+
572
+ # for image
573
+ _key = self.to_k_SSR(_hidden_states)
574
+ _value = self.to_v_SSR(_hidden_states)
575
+ inner_dim = _key.shape[-1]
576
+ head_dim = inner_dim // attn.heads
577
+
578
+ _key = _key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
579
+ _value = _value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
580
+
581
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
582
+ # TODO: add support for attn.scale when we move to Torch 2.1
583
+ _hidden_states = F.scaled_dot_product_attention(
584
+ query, _key, _value, attn_mask=None, dropout_p=0.0, is_causal=False
585
+ )
586
+
587
+ _hidden_states = _hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
588
+ _hidden_states = _hidden_states.to(query.dtype)
589
+
590
+ hidden_states = self.scale * _hidden_states + hidden_states
591
+
592
+ # linear proj
593
+ hidden_states = attn.to_out[0](hidden_states)
594
+ # dropout
595
+ hidden_states = attn.to_out[1](hidden_states)
596
+
597
+ if input_ndim == 4:
598
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
599
+
600
+ if attn.residual_connection:
601
+ hidden_states = hidden_states + residual
602
+
603
+ hidden_states = hidden_states / attn.rescale_output_factor
604
+
605
+ return hidden_states
606
+
607
+
608
+ class SSRAttnProcessor_visual(nn.Module):
609
+ r"""
610
+ Attention processor for attn visualization.
611
+ """
612
+
613
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1, attnstore=None, place_in_unet=None):
614
+ super().__init__()
615
+ self.hidden_size = hidden_size
616
+ self.cross_attention_dim = cross_attention_dim
617
+ self.scale = scale
618
+ self.to_k_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
619
+ self.to_v_SSR = nn.Linear(cross_attention_dim, hidden_size, bias=False)
620
+ self.attnstore = attnstore
621
+ self.place_in_unet = place_in_unet
622
+
623
+ def __call__(
624
+ self,
625
+ attn,
626
+ hidden_states,
627
+ encoder_hidden_states=None,
628
+ attention_mask=None,
629
+ temb=None,
630
+ ):
631
+ residual = hidden_states
632
+
633
+ if attn.spatial_norm is not None:
634
+ hidden_states = attn.spatial_norm(hidden_states, temb)
635
+
636
+ input_ndim = hidden_states.ndim
637
+
638
+ if input_ndim == 4:
639
+ batch_size, channel, height, width = hidden_states.shape
640
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
641
+
642
+ batch_size, sequence_length, _ = (
643
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
644
+ )
645
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
646
+
647
+ if attn.group_norm is not None:
648
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
649
+
650
+ # query = self.to_q_SSR(hidden_states)
651
+ query = attn.to_q(hidden_states)
652
+ query = attn.head_to_batch_dim(query)
653
+
654
+ if encoder_hidden_states is None:
655
+ encoder_hidden_states = hidden_states
656
+ elif attn.norm_cross:
657
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
658
+
659
+ _hidden_states = encoder_hidden_states
660
+ _key = self.to_k_SSR(_hidden_states)
661
+ _value = self.to_v_SSR(_hidden_states)
662
+ _key = attn.head_to_batch_dim(_key)
663
+ _value = attn.head_to_batch_dim(_value)
664
+ _attention_probs = attn.get_attention_scores(query, _key, None)
665
+
666
+ # store attention maps
667
+ is_cross = encoder_hidden_states is not None
668
+ self.attnstore(_attention_probs, is_cross, self.place_in_unet)
669
+
670
+ _hidden_states = torch.bmm(_attention_probs, _value)
671
+ _hidden_states = attn.batch_to_head_dim(_hidden_states)
672
+ hidden_states = self.scale * _hidden_states
673
+
674
+ # linear proj
675
+ hidden_states = attn.to_out[0](hidden_states)
676
+ # dropout
677
+ hidden_states = attn.to_out[1](hidden_states)
678
+
679
+ if input_ndim == 4:
680
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
681
+
682
+ if attn.residual_connection:
683
+ hidden_states = hidden_states + residual
684
+
685
+ hidden_states = hidden_states / attn.rescale_output_factor
686
+
687
+ return hidden_states
detail_encoder/encoder_plus.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import torch
3
+ from torchvision import transforms
4
+ from transformers import CLIPImageProcessor
5
+ from transformers import CLIPVisionModel as OriginalCLIPVisionModel
6
+ from ._clip import CLIPVisionModel
7
+ from PIL import Image
8
+ import torch.nn.functional as F
9
+ import torch.nn as nn
10
+ import os
11
+
12
+ def is_torch2_available():
13
+ return hasattr(F, "scaled_dot_product_attention")
14
+ if is_torch2_available():
15
+ from .attention_processor import SSRAttnProcessor2_0 as SSRAttnProcessor, AttnProcessor2_0 as AttnProcessor
16
+ else:
17
+ from .attention_processor import SSRAttnProcessor, AttnProcessor
18
+ from .resampler import Resampler
19
+
20
+ class detail_encoder(torch.nn.Module):
21
+ """from SSR-encoder"""
22
+ def __init__(self, unet, image_encoder_path, device="cuda", dtype=torch.float32):
23
+ super().__init__()
24
+ self.device = device
25
+ self.dtype = dtype
26
+
27
+ # load image encoder
28
+ clip_encoder = OriginalCLIPVisionModel.from_pretrained(image_encoder_path)
29
+ self.image_encoder = CLIPVisionModel(clip_encoder.config)
30
+ state_dict = clip_encoder.state_dict()
31
+ self.image_encoder.load_state_dict(state_dict, strict=False)
32
+ self.image_encoder.to(self.device, self.dtype)
33
+ del clip_encoder
34
+ self.clip_image_processor = CLIPImageProcessor()
35
+
36
+ # load SSR layers
37
+ attn_procs = {}
38
+ for name in unet.attn_processors.keys():
39
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
40
+ if name.startswith("mid_block"):
41
+ hidden_size = unet.config.block_out_channels[-1]
42
+ elif name.startswith("up_blocks"):
43
+ block_id = int(name[len("up_blocks.")])
44
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
45
+ elif name.startswith("down_blocks"):
46
+ block_id = int(name[len("down_blocks.")])
47
+ hidden_size = unet.config.block_out_channels[block_id]
48
+ if cross_attention_dim is None:
49
+ attn_procs[name] = AttnProcessor()
50
+ else:
51
+ attn_procs[name] = SSRAttnProcessor(hidden_size=hidden_size, cross_attention_dim=1024, scale=1).to(self.device, dtype=self.dtype)
52
+ unet.set_attn_processor(attn_procs)
53
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
54
+ self.SSR_layers = adapter_modules
55
+ self.SSR_layers.to(self.device, dtype=self.dtype)
56
+ self.resampler = self.init_proj()
57
+
58
+ def init_proj(self):
59
+ resampler = Resampler().to(self.device, dtype=self.dtype)
60
+ return resampler
61
+
62
+ def forward(self, img):
63
+ image_embeds = self.image_encoder(img, output_hidden_states=True)['hidden_states'][2::2]
64
+ image_embeds = torch.cat(image_embeds, dim=1)
65
+ image_embeds = self.resampler(image_embeds)
66
+ return image_embeds
67
+
68
+ @torch.inference_mode()
69
+ def get_image_embeds(self, pil_image):
70
+ if isinstance(pil_image, Image.Image):
71
+ pil_image = [pil_image]
72
+ clip_image = []
73
+ for pil in pil_image:
74
+ tensor_image = self.clip_image_processor(images=pil, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)
75
+ clip_image.append(tensor_image)
76
+ clip_image = torch.cat(clip_image, dim=0)
77
+
78
+ # cond
79
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True)['hidden_states'][2::2] # 1 257*12 1024
80
+ clip_image_embeds = torch.cat(clip_image_embeds, dim=1)
81
+ uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True)['hidden_states'][2::2]
82
+ uncond_clip_image_embeds = torch.cat(uncond_clip_image_embeds, dim=1)
83
+ clip_image_embeds = self.resampler(clip_image_embeds)
84
+ uncond_clip_image_embeds = self.resampler(uncond_clip_image_embeds)
85
+ return clip_image_embeds, uncond_clip_image_embeds
86
+
87
+ def generate(
88
+ self,
89
+ id_image,
90
+ makeup_image,
91
+ seed=None,
92
+ guidance_scale=2,
93
+ num_inference_steps=30,
94
+ pipe=None,
95
+ **kwargs,
96
+ ):
97
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(makeup_image)
98
+
99
+ prompt_embeds = image_prompt_embeds
100
+ negative_prompt_embeds = uncond_image_prompt_embeds
101
+
102
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
103
+ image = pipe(
104
+ image=id_image,
105
+ prompt_embeds=prompt_embeds,
106
+ negative_prompt_embeds=negative_prompt_embeds,
107
+ guidance_scale=guidance_scale,
108
+ num_inference_steps=num_inference_steps,
109
+ generator=generator,
110
+ **kwargs,
111
+ ).images[0]
112
+
113
+ return image
detail_encoder/resampler.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import torch.nn.functional as F
4
+ from torch import nn, einsum
5
+ from inspect import isfunction
6
+
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+ def uniq(arr):
12
+ return{el: True for el in arr}.keys()
13
+
14
+
15
+ def default(val, d):
16
+ if exists(val):
17
+ return val
18
+ return d() if isfunction(d) else d
19
+
20
+
21
+ def max_neg_value(t):
22
+ return -torch.finfo(t.dtype).max
23
+
24
+
25
+ def init_(tensor):
26
+ dim = tensor.shape[-1]
27
+ std = 1 / math.sqrt(dim)
28
+ tensor.uniform_(-std, std)
29
+ return tensor
30
+
31
+
32
+ # feedforward
33
+ class GEGLU(nn.Module):
34
+ def __init__(self, dim_in, dim_out):
35
+ super().__init__()
36
+ self.proj = nn.Linear(dim_in, dim_out * 2)
37
+
38
+ def forward(self, x):
39
+ x, gate = self.proj(x).chunk(2, dim=-1)
40
+ return x * F.gelu(gate)
41
+
42
+
43
+ class FeedForward(nn.Module):
44
+ def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
45
+ super().__init__()
46
+ inner_dim = int(dim * mult)
47
+ dim_out = default(dim_out, dim)
48
+ project_in = nn.Sequential(
49
+ nn.Linear(dim, inner_dim),
50
+ nn.GELU()
51
+ ) if not glu else GEGLU(dim, inner_dim)
52
+
53
+ self.net = nn.Sequential(
54
+ project_in,
55
+ nn.Dropout(dropout),
56
+ nn.Linear(inner_dim, dim_out)
57
+ )
58
+
59
+ def forward(self, x):
60
+ return self.net(x)
61
+
62
+
63
+ class SelfAttention(nn.Module):
64
+ def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
65
+ super().__init__()
66
+ inner_dim = dim_head * heads
67
+ self.scale = dim_head ** -0.5
68
+ self.heads = heads
69
+
70
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
71
+ self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
72
+ self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
73
+
74
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) )
75
+
76
+ def forward(self, x):
77
+ q = self.to_q(x) # B*N*(H*C)
78
+ k = self.to_k(x) # B*N*(H*C)
79
+ v = self.to_v(x) # B*N*(H*C)
80
+
81
+ B, N, HC = q.shape
82
+ H = self.heads
83
+ C = HC // H
84
+
85
+ q = q.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
86
+ k = k.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
87
+ v = v.view(B,N,H,C).permute(0,2,1,3).reshape(B*H,N,C) # (B*H)*N*C
88
+
89
+ sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale # (B*H)*N*N
90
+ attn = sim.softmax(dim=-1) # (B*H)*N*N
91
+
92
+ out = torch.einsum('b i j, b j c -> b i c', attn, v) # (B*H)*N*C
93
+ out = out.view(B,H,N,C).permute(0,2,1,3).reshape(B,N,(H*C)) # B*N*(H*C)
94
+
95
+ return self.to_out(out)
96
+
97
+
98
+
99
+ class Resampler(nn.Module):
100
+ def __init__(self, query_dim=1024, n_heads=8, d_head=64):
101
+ super().__init__()
102
+
103
+ self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
104
+ self.ff = FeedForward(query_dim, glu=True)
105
+
106
+ self.norm1 = nn.LayerNorm(query_dim)
107
+ self.norm2 = nn.LayerNorm(query_dim)
108
+
109
+ def forward(self, x):
110
+ x = x + self.attn(self.norm1(x))
111
+ x = x + self.ff(self.norm2(x))
112
+ return x
diffusers/.DS_Store ADDED
Binary file (8.2 kB). View file
 
diffusers/__init__.py ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.23.1"
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .utils import (
6
+ DIFFUSERS_SLOW_IMPORT,
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_flax_available,
10
+ is_k_diffusion_available,
11
+ is_librosa_available,
12
+ is_note_seq_available,
13
+ is_onnx_available,
14
+ is_scipy_available,
15
+ is_torch_available,
16
+ is_torchsde_available,
17
+ is_transformers_available,
18
+ )
19
+
20
+
21
+ # Lazy Import based on
22
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
23
+
24
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
25
+ # and is used to defer the actual importing for when the objects are requested.
26
+ # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
27
+
28
+ _import_structure = {
29
+ "configuration_utils": ["ConfigMixin"],
30
+ "models": [],
31
+ "pipelines": [],
32
+ "schedulers": [],
33
+ "utils": [
34
+ "OptionalDependencyNotAvailable",
35
+ "is_flax_available",
36
+ "is_inflect_available",
37
+ "is_invisible_watermark_available",
38
+ "is_k_diffusion_available",
39
+ "is_k_diffusion_version",
40
+ "is_librosa_available",
41
+ "is_note_seq_available",
42
+ "is_onnx_available",
43
+ "is_scipy_available",
44
+ "is_torch_available",
45
+ "is_torchsde_available",
46
+ "is_transformers_available",
47
+ "is_transformers_version",
48
+ "is_unidecode_available",
49
+ "logging",
50
+ ],
51
+ }
52
+
53
+ try:
54
+ if not is_onnx_available():
55
+ raise OptionalDependencyNotAvailable()
56
+ except OptionalDependencyNotAvailable:
57
+ from .utils import dummy_onnx_objects # noqa F403
58
+
59
+ _import_structure["utils.dummy_onnx_objects"] = [
60
+ name for name in dir(dummy_onnx_objects) if not name.startswith("_")
61
+ ]
62
+
63
+ else:
64
+ _import_structure["pipelines"].extend(["OnnxRuntimeModel"])
65
+
66
+ try:
67
+ if not is_torch_available():
68
+ raise OptionalDependencyNotAvailable()
69
+ except OptionalDependencyNotAvailable:
70
+ from .utils import dummy_pt_objects # noqa F403
71
+
72
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
73
+
74
+ else:
75
+ _import_structure["models"].extend(
76
+ [
77
+ "AsymmetricAutoencoderKL",
78
+ "AutoencoderKL",
79
+ "AutoencoderTiny",
80
+ "ConsistencyDecoderVAE",
81
+ "ControlNetModel",
82
+ "ModelMixin",
83
+ "MotionAdapter",
84
+ "MultiAdapter",
85
+ "PriorTransformer",
86
+ "T2IAdapter",
87
+ "T5FilmDecoder",
88
+ "Transformer2DModel",
89
+ "UNet1DModel",
90
+ "UNet2DConditionModel",
91
+ "UNet2DModel",
92
+ "UNet3DConditionModel",
93
+ "UNetMotionModel",
94
+ "VQModel",
95
+ ]
96
+ )
97
+ _import_structure["optimization"] = [
98
+ "get_constant_schedule",
99
+ "get_constant_schedule_with_warmup",
100
+ "get_cosine_schedule_with_warmup",
101
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
102
+ "get_linear_schedule_with_warmup",
103
+ "get_polynomial_decay_schedule_with_warmup",
104
+ "get_scheduler",
105
+ ]
106
+
107
+ _import_structure["pipelines"].extend(
108
+ [
109
+ "AudioPipelineOutput",
110
+ "AutoPipelineForImage2Image",
111
+ "AutoPipelineForInpainting",
112
+ "AutoPipelineForText2Image",
113
+ "ConsistencyModelPipeline",
114
+ "DanceDiffusionPipeline",
115
+ "DDIMPipeline",
116
+ "DDPMPipeline",
117
+ "DiffusionPipeline",
118
+ "DiTPipeline",
119
+ "ImagePipelineOutput",
120
+ "KarrasVePipeline",
121
+ "LDMPipeline",
122
+ "LDMSuperResolutionPipeline",
123
+ "PNDMPipeline",
124
+ "RePaintPipeline",
125
+ "ScoreSdeVePipeline",
126
+ ]
127
+ )
128
+ _import_structure["schedulers"].extend(
129
+ [
130
+ "CMStochasticIterativeScheduler",
131
+ "DDIMInverseScheduler",
132
+ "DDIMParallelScheduler",
133
+ "DDIMScheduler",
134
+ "DDPMParallelScheduler",
135
+ "DDPMScheduler",
136
+ "DDPMWuerstchenScheduler",
137
+ "DEISMultistepScheduler",
138
+ "DPMSolverMultistepInverseScheduler",
139
+ "DPMSolverMultistepScheduler",
140
+ "DPMSolverSinglestepScheduler",
141
+ "EulerAncestralDiscreteScheduler",
142
+ "EulerDiscreteScheduler",
143
+ "HeunDiscreteScheduler",
144
+ "IPNDMScheduler",
145
+ "KarrasVeScheduler",
146
+ "KDPM2AncestralDiscreteScheduler",
147
+ "KDPM2DiscreteScheduler",
148
+ "LCMScheduler",
149
+ "PNDMScheduler",
150
+ "RePaintScheduler",
151
+ "SchedulerMixin",
152
+ "ScoreSdeVeScheduler",
153
+ "UnCLIPScheduler",
154
+ "UniPCMultistepScheduler",
155
+ "VQDiffusionScheduler",
156
+ ]
157
+ )
158
+ _import_structure["training_utils"] = ["EMAModel"]
159
+
160
+ try:
161
+ if not (is_torch_available() and is_scipy_available()):
162
+ raise OptionalDependencyNotAvailable()
163
+ except OptionalDependencyNotAvailable:
164
+ from .utils import dummy_torch_and_scipy_objects # noqa F403
165
+
166
+ _import_structure["utils.dummy_torch_and_scipy_objects"] = [
167
+ name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
168
+ ]
169
+
170
+ else:
171
+ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
172
+
173
+ try:
174
+ if not (is_torch_available() and is_torchsde_available()):
175
+ raise OptionalDependencyNotAvailable()
176
+ except OptionalDependencyNotAvailable:
177
+ from .utils import dummy_torch_and_torchsde_objects # noqa F403
178
+
179
+ _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
180
+ name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
181
+ ]
182
+
183
+ else:
184
+ _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])
185
+
186
+ try:
187
+ if not (is_torch_available() and is_transformers_available()):
188
+ raise OptionalDependencyNotAvailable()
189
+ except OptionalDependencyNotAvailable:
190
+ from .utils import dummy_torch_and_transformers_objects # noqa F403
191
+
192
+ _import_structure["utils.dummy_torch_and_transformers_objects"] = [
193
+ name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
194
+ ]
195
+
196
+ else:
197
+ _import_structure["pipelines"].extend(
198
+ [
199
+ "AltDiffusionImg2ImgPipeline",
200
+ "AltDiffusionPipeline",
201
+ "AnimateDiffPipeline",
202
+ "AudioLDM2Pipeline",
203
+ "AudioLDM2ProjectionModel",
204
+ "AudioLDM2UNet2DConditionModel",
205
+ "AudioLDMPipeline",
206
+ "BlipDiffusionControlNetPipeline",
207
+ "BlipDiffusionPipeline",
208
+ "CLIPImageProjection",
209
+ "CycleDiffusionPipeline",
210
+ "IFImg2ImgPipeline",
211
+ "IFImg2ImgSuperResolutionPipeline",
212
+ "IFInpaintingPipeline",
213
+ "IFInpaintingSuperResolutionPipeline",
214
+ "IFPipeline",
215
+ "IFSuperResolutionPipeline",
216
+ "ImageTextPipelineOutput",
217
+ "KandinskyCombinedPipeline",
218
+ "KandinskyImg2ImgCombinedPipeline",
219
+ "KandinskyImg2ImgPipeline",
220
+ "KandinskyInpaintCombinedPipeline",
221
+ "KandinskyInpaintPipeline",
222
+ "KandinskyPipeline",
223
+ "KandinskyPriorPipeline",
224
+ "KandinskyV22CombinedPipeline",
225
+ "KandinskyV22ControlnetImg2ImgPipeline",
226
+ "KandinskyV22ControlnetPipeline",
227
+ "KandinskyV22Img2ImgCombinedPipeline",
228
+ "KandinskyV22Img2ImgPipeline",
229
+ "KandinskyV22InpaintCombinedPipeline",
230
+ "KandinskyV22InpaintPipeline",
231
+ "KandinskyV22Pipeline",
232
+ "KandinskyV22PriorEmb2EmbPipeline",
233
+ "KandinskyV22PriorPipeline",
234
+ "LatentConsistencyModelImg2ImgPipeline",
235
+ "LatentConsistencyModelPipeline",
236
+ "LDMTextToImagePipeline",
237
+ "MusicLDMPipeline",
238
+ "PaintByExamplePipeline",
239
+ "PixArtAlphaPipeline",
240
+ "SemanticStableDiffusionPipeline",
241
+ "ShapEImg2ImgPipeline",
242
+ "ShapEPipeline",
243
+ "StableDiffusionAdapterPipeline",
244
+ "StableDiffusionAttendAndExcitePipeline",
245
+ "StableDiffusionControlNetImg2ImgPipeline",
246
+ "StableDiffusionControlNetInpaintPipeline",
247
+ "StableDiffusionControlNetPipeline",
248
+ "StableDiffusionDepth2ImgPipeline",
249
+ "StableDiffusionDiffEditPipeline",
250
+ "StableDiffusionGLIGENPipeline",
251
+ "StableDiffusionGLIGENTextImagePipeline",
252
+ "StableDiffusionImageVariationPipeline",
253
+ "StableDiffusionImg2ImgPipeline",
254
+ "StableDiffusionInpaintPipeline",
255
+ "StableDiffusionInpaintPipelineLegacy",
256
+ "StableDiffusionInstructPix2PixPipeline",
257
+ "StableDiffusionLatentUpscalePipeline",
258
+ "StableDiffusionLDM3DPipeline",
259
+ "StableDiffusionModelEditingPipeline",
260
+ "StableDiffusionPanoramaPipeline",
261
+ "StableDiffusionParadigmsPipeline",
262
+ "StableDiffusionPipeline",
263
+ "StableDiffusionPipelineSafe",
264
+ "StableDiffusionPix2PixZeroPipeline",
265
+ "StableDiffusionSAGPipeline",
266
+ "StableDiffusionUpscalePipeline",
267
+ "StableDiffusionXLAdapterPipeline",
268
+ "StableDiffusionXLControlNetImg2ImgPipeline",
269
+ "StableDiffusionXLControlNetInpaintPipeline",
270
+ "StableDiffusionXLControlNetPipeline",
271
+ "StableDiffusionXLImg2ImgPipeline",
272
+ "StableDiffusionXLInpaintPipeline",
273
+ "StableDiffusionXLInstructPix2PixPipeline",
274
+ "StableDiffusionXLPipeline",
275
+ "StableUnCLIPImg2ImgPipeline",
276
+ "StableUnCLIPPipeline",
277
+ "TextToVideoSDPipeline",
278
+ "TextToVideoZeroPipeline",
279
+ "UnCLIPImageVariationPipeline",
280
+ "UnCLIPPipeline",
281
+ "UniDiffuserModel",
282
+ "UniDiffuserPipeline",
283
+ "UniDiffuserTextDecoder",
284
+ "VersatileDiffusionDualGuidedPipeline",
285
+ "VersatileDiffusionImageVariationPipeline",
286
+ "VersatileDiffusionPipeline",
287
+ "VersatileDiffusionTextToImagePipeline",
288
+ "VideoToVideoSDPipeline",
289
+ "VQDiffusionPipeline",
290
+ "WuerstchenCombinedPipeline",
291
+ "WuerstchenDecoderPipeline",
292
+ "WuerstchenPriorPipeline",
293
+ ]
294
+ )
295
+
296
+ try:
297
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
298
+ raise OptionalDependencyNotAvailable()
299
+ except OptionalDependencyNotAvailable:
300
+ from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
301
+
302
+ _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
303
+ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
304
+ ]
305
+
306
+ else:
307
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
308
+
309
+ try:
310
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
311
+ raise OptionalDependencyNotAvailable()
312
+ except OptionalDependencyNotAvailable:
313
+ from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
314
+
315
+ _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
316
+ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
317
+ ]
318
+
319
+ else:
320
+ _import_structure["pipelines"].extend(
321
+ [
322
+ "OnnxStableDiffusionImg2ImgPipeline",
323
+ "OnnxStableDiffusionInpaintPipeline",
324
+ "OnnxStableDiffusionInpaintPipelineLegacy",
325
+ "OnnxStableDiffusionPipeline",
326
+ "OnnxStableDiffusionUpscalePipeline",
327
+ "StableDiffusionOnnxPipeline",
328
+ ]
329
+ )
330
+
331
+ try:
332
+ if not (is_torch_available() and is_librosa_available()):
333
+ raise OptionalDependencyNotAvailable()
334
+ except OptionalDependencyNotAvailable:
335
+ from .utils import dummy_torch_and_librosa_objects # noqa F403
336
+
337
+ _import_structure["utils.dummy_torch_and_librosa_objects"] = [
338
+ name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
339
+ ]
340
+
341
+ else:
342
+ _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
343
+
344
+ try:
345
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
346
+ raise OptionalDependencyNotAvailable()
347
+ except OptionalDependencyNotAvailable:
348
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
349
+
350
+ _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
351
+ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
352
+ ]
353
+
354
+
355
+ else:
356
+ _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
357
+
358
+ try:
359
+ if not is_flax_available():
360
+ raise OptionalDependencyNotAvailable()
361
+ except OptionalDependencyNotAvailable:
362
+ from .utils import dummy_flax_objects # noqa F403
363
+
364
+ _import_structure["utils.dummy_flax_objects"] = [
365
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
366
+ ]
367
+
368
+
369
+ else:
370
+ _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
371
+ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
372
+ _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
373
+ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
374
+ _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
375
+ _import_structure["schedulers"].extend(
376
+ [
377
+ "FlaxDDIMScheduler",
378
+ "FlaxDDPMScheduler",
379
+ "FlaxDPMSolverMultistepScheduler",
380
+ "FlaxEulerDiscreteScheduler",
381
+ "FlaxKarrasVeScheduler",
382
+ "FlaxLMSDiscreteScheduler",
383
+ "FlaxPNDMScheduler",
384
+ "FlaxSchedulerMixin",
385
+ "FlaxScoreSdeVeScheduler",
386
+ ]
387
+ )
388
+
389
+
390
+ try:
391
+ if not (is_flax_available() and is_transformers_available()):
392
+ raise OptionalDependencyNotAvailable()
393
+ except OptionalDependencyNotAvailable:
394
+ from .utils import dummy_flax_and_transformers_objects # noqa F403
395
+
396
+ _import_structure["utils.dummy_flax_and_transformers_objects"] = [
397
+ name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
398
+ ]
399
+
400
+
401
+ else:
402
+ _import_structure["pipelines"].extend(
403
+ [
404
+ "FlaxStableDiffusionControlNetPipeline",
405
+ "FlaxStableDiffusionImg2ImgPipeline",
406
+ "FlaxStableDiffusionInpaintPipeline",
407
+ "FlaxStableDiffusionPipeline",
408
+ "FlaxStableDiffusionXLPipeline",
409
+ ]
410
+ )
411
+
412
+ try:
413
+ if not (is_note_seq_available()):
414
+ raise OptionalDependencyNotAvailable()
415
+ except OptionalDependencyNotAvailable:
416
+ from .utils import dummy_note_seq_objects # noqa F403
417
+
418
+ _import_structure["utils.dummy_note_seq_objects"] = [
419
+ name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
420
+ ]
421
+
422
+
423
+ else:
424
+ _import_structure["pipelines"].extend(["MidiProcessor"])
425
+
426
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
427
+ from .configuration_utils import ConfigMixin
428
+
429
+ try:
430
+ if not is_onnx_available():
431
+ raise OptionalDependencyNotAvailable()
432
+ except OptionalDependencyNotAvailable:
433
+ from .utils.dummy_onnx_objects import * # noqa F403
434
+ else:
435
+ from .pipelines import OnnxRuntimeModel
436
+
437
+ try:
438
+ if not is_torch_available():
439
+ raise OptionalDependencyNotAvailable()
440
+ except OptionalDependencyNotAvailable:
441
+ from .utils.dummy_pt_objects import * # noqa F403
442
+ else:
443
+ from .models import (
444
+ AsymmetricAutoencoderKL,
445
+ AutoencoderKL,
446
+ AutoencoderTiny,
447
+ ConsistencyDecoderVAE,
448
+ ControlNetModel,
449
+ ModelMixin,
450
+ MotionAdapter,
451
+ MultiAdapter,
452
+ PriorTransformer,
453
+ T2IAdapter,
454
+ T5FilmDecoder,
455
+ Transformer2DModel,
456
+ UNet1DModel,
457
+ UNet2DConditionModel,
458
+ UNet2DModel,
459
+ UNet3DConditionModel,
460
+ UNetMotionModel,
461
+ VQModel,
462
+ )
463
+ from .optimization import (
464
+ get_constant_schedule,
465
+ get_constant_schedule_with_warmup,
466
+ get_cosine_schedule_with_warmup,
467
+ get_cosine_with_hard_restarts_schedule_with_warmup,
468
+ get_linear_schedule_with_warmup,
469
+ get_polynomial_decay_schedule_with_warmup,
470
+ get_scheduler,
471
+ )
472
+ from .pipelines import (
473
+ AudioPipelineOutput,
474
+ AutoPipelineForImage2Image,
475
+ AutoPipelineForInpainting,
476
+ AutoPipelineForText2Image,
477
+ BlipDiffusionControlNetPipeline,
478
+ BlipDiffusionPipeline,
479
+ CLIPImageProjection,
480
+ ConsistencyModelPipeline,
481
+ DanceDiffusionPipeline,
482
+ DDIMPipeline,
483
+ DDPMPipeline,
484
+ DiffusionPipeline,
485
+ DiTPipeline,
486
+ ImagePipelineOutput,
487
+ KarrasVePipeline,
488
+ LDMPipeline,
489
+ LDMSuperResolutionPipeline,
490
+ PNDMPipeline,
491
+ RePaintPipeline,
492
+ ScoreSdeVePipeline,
493
+ )
494
+ from .schedulers import (
495
+ CMStochasticIterativeScheduler,
496
+ DDIMInverseScheduler,
497
+ DDIMParallelScheduler,
498
+ DDIMScheduler,
499
+ DDPMParallelScheduler,
500
+ DDPMScheduler,
501
+ DDPMWuerstchenScheduler,
502
+ DEISMultistepScheduler,
503
+ DPMSolverMultistepInverseScheduler,
504
+ DPMSolverMultistepScheduler,
505
+ DPMSolverSinglestepScheduler,
506
+ EulerAncestralDiscreteScheduler,
507
+ EulerDiscreteScheduler,
508
+ HeunDiscreteScheduler,
509
+ IPNDMScheduler,
510
+ KarrasVeScheduler,
511
+ KDPM2AncestralDiscreteScheduler,
512
+ KDPM2DiscreteScheduler,
513
+ LCMScheduler,
514
+ PNDMScheduler,
515
+ RePaintScheduler,
516
+ SchedulerMixin,
517
+ ScoreSdeVeScheduler,
518
+ UnCLIPScheduler,
519
+ UniPCMultistepScheduler,
520
+ VQDiffusionScheduler,
521
+ )
522
+ from .training_utils import EMAModel
523
+
524
+ try:
525
+ if not (is_torch_available() and is_scipy_available()):
526
+ raise OptionalDependencyNotAvailable()
527
+ except OptionalDependencyNotAvailable:
528
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
529
+ else:
530
+ from .schedulers import LMSDiscreteScheduler
531
+
532
+ try:
533
+ if not (is_torch_available() and is_torchsde_available()):
534
+ raise OptionalDependencyNotAvailable()
535
+ except OptionalDependencyNotAvailable:
536
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
537
+ else:
538
+ from .schedulers import DPMSolverSDEScheduler
539
+
540
+ try:
541
+ if not (is_torch_available() and is_transformers_available()):
542
+ raise OptionalDependencyNotAvailable()
543
+ except OptionalDependencyNotAvailable:
544
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
545
+ else:
546
+ from .pipelines import (
547
+ AltDiffusionImg2ImgPipeline,
548
+ AltDiffusionPipeline,
549
+ AnimateDiffPipeline,
550
+ AudioLDM2Pipeline,
551
+ AudioLDM2ProjectionModel,
552
+ AudioLDM2UNet2DConditionModel,
553
+ AudioLDMPipeline,
554
+ CLIPImageProjection,
555
+ CycleDiffusionPipeline,
556
+ IFImg2ImgPipeline,
557
+ IFImg2ImgSuperResolutionPipeline,
558
+ IFInpaintingPipeline,
559
+ IFInpaintingSuperResolutionPipeline,
560
+ IFPipeline,
561
+ IFSuperResolutionPipeline,
562
+ ImageTextPipelineOutput,
563
+ KandinskyCombinedPipeline,
564
+ KandinskyImg2ImgCombinedPipeline,
565
+ KandinskyImg2ImgPipeline,
566
+ KandinskyInpaintCombinedPipeline,
567
+ KandinskyInpaintPipeline,
568
+ KandinskyPipeline,
569
+ KandinskyPriorPipeline,
570
+ KandinskyV22CombinedPipeline,
571
+ KandinskyV22ControlnetImg2ImgPipeline,
572
+ KandinskyV22ControlnetPipeline,
573
+ KandinskyV22Img2ImgCombinedPipeline,
574
+ KandinskyV22Img2ImgPipeline,
575
+ KandinskyV22InpaintCombinedPipeline,
576
+ KandinskyV22InpaintPipeline,
577
+ KandinskyV22Pipeline,
578
+ KandinskyV22PriorEmb2EmbPipeline,
579
+ KandinskyV22PriorPipeline,
580
+ LatentConsistencyModelImg2ImgPipeline,
581
+ LatentConsistencyModelPipeline,
582
+ LDMTextToImagePipeline,
583
+ MusicLDMPipeline,
584
+ PaintByExamplePipeline,
585
+ PixArtAlphaPipeline,
586
+ SemanticStableDiffusionPipeline,
587
+ ShapEImg2ImgPipeline,
588
+ ShapEPipeline,
589
+ StableDiffusionAdapterPipeline,
590
+ StableDiffusionAttendAndExcitePipeline,
591
+ StableDiffusionControlNetImg2ImgPipeline,
592
+ StableDiffusionControlNetInpaintPipeline,
593
+ StableDiffusionControlNetPipeline,
594
+ StableDiffusionDepth2ImgPipeline,
595
+ StableDiffusionDiffEditPipeline,
596
+ StableDiffusionGLIGENPipeline,
597
+ StableDiffusionGLIGENTextImagePipeline,
598
+ StableDiffusionImageVariationPipeline,
599
+ StableDiffusionImg2ImgPipeline,
600
+ StableDiffusionInpaintPipeline,
601
+ StableDiffusionInpaintPipelineLegacy,
602
+ StableDiffusionInstructPix2PixPipeline,
603
+ StableDiffusionLatentUpscalePipeline,
604
+ StableDiffusionLDM3DPipeline,
605
+ StableDiffusionModelEditingPipeline,
606
+ StableDiffusionPanoramaPipeline,
607
+ StableDiffusionParadigmsPipeline,
608
+ StableDiffusionPipeline,
609
+ StableDiffusionPipelineSafe,
610
+ StableDiffusionPix2PixZeroPipeline,
611
+ StableDiffusionSAGPipeline,
612
+ StableDiffusionUpscalePipeline,
613
+ StableDiffusionXLAdapterPipeline,
614
+ StableDiffusionXLControlNetImg2ImgPipeline,
615
+ StableDiffusionXLControlNetInpaintPipeline,
616
+ StableDiffusionXLControlNetPipeline,
617
+ StableDiffusionXLImg2ImgPipeline,
618
+ StableDiffusionXLInpaintPipeline,
619
+ StableDiffusionXLInstructPix2PixPipeline,
620
+ StableDiffusionXLPipeline,
621
+ StableUnCLIPImg2ImgPipeline,
622
+ StableUnCLIPPipeline,
623
+ TextToVideoSDPipeline,
624
+ TextToVideoZeroPipeline,
625
+ UnCLIPImageVariationPipeline,
626
+ UnCLIPPipeline,
627
+ UniDiffuserModel,
628
+ UniDiffuserPipeline,
629
+ UniDiffuserTextDecoder,
630
+ VersatileDiffusionDualGuidedPipeline,
631
+ VersatileDiffusionImageVariationPipeline,
632
+ VersatileDiffusionPipeline,
633
+ VersatileDiffusionTextToImagePipeline,
634
+ VideoToVideoSDPipeline,
635
+ VQDiffusionPipeline,
636
+ WuerstchenCombinedPipeline,
637
+ WuerstchenDecoderPipeline,
638
+ WuerstchenPriorPipeline,
639
+ )
640
+
641
+ try:
642
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
643
+ raise OptionalDependencyNotAvailable()
644
+ except OptionalDependencyNotAvailable:
645
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
646
+ else:
647
+ from .pipelines import StableDiffusionKDiffusionPipeline
648
+
649
+ try:
650
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
651
+ raise OptionalDependencyNotAvailable()
652
+ except OptionalDependencyNotAvailable:
653
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
654
+ else:
655
+ from .pipelines import (
656
+ OnnxStableDiffusionImg2ImgPipeline,
657
+ OnnxStableDiffusionInpaintPipeline,
658
+ OnnxStableDiffusionInpaintPipelineLegacy,
659
+ OnnxStableDiffusionPipeline,
660
+ OnnxStableDiffusionUpscalePipeline,
661
+ StableDiffusionOnnxPipeline,
662
+ )
663
+
664
+ try:
665
+ if not (is_torch_available() and is_librosa_available()):
666
+ raise OptionalDependencyNotAvailable()
667
+ except OptionalDependencyNotAvailable:
668
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
669
+ else:
670
+ from .pipelines import AudioDiffusionPipeline, Mel
671
+
672
+ try:
673
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
674
+ raise OptionalDependencyNotAvailable()
675
+ except OptionalDependencyNotAvailable:
676
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
677
+ else:
678
+ from .pipelines import SpectrogramDiffusionPipeline
679
+
680
+ try:
681
+ if not is_flax_available():
682
+ raise OptionalDependencyNotAvailable()
683
+ except OptionalDependencyNotAvailable:
684
+ from .utils.dummy_flax_objects import * # noqa F403
685
+ else:
686
+ from .models.controlnet_flax import FlaxControlNetModel
687
+ from .models.modeling_flax_utils import FlaxModelMixin
688
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
689
+ from .models.vae_flax import FlaxAutoencoderKL
690
+ from .pipelines import FlaxDiffusionPipeline
691
+ from .schedulers import (
692
+ FlaxDDIMScheduler,
693
+ FlaxDDPMScheduler,
694
+ FlaxDPMSolverMultistepScheduler,
695
+ FlaxEulerDiscreteScheduler,
696
+ FlaxKarrasVeScheduler,
697
+ FlaxLMSDiscreteScheduler,
698
+ FlaxPNDMScheduler,
699
+ FlaxSchedulerMixin,
700
+ FlaxScoreSdeVeScheduler,
701
+ )
702
+
703
+ try:
704
+ if not (is_flax_available() and is_transformers_available()):
705
+ raise OptionalDependencyNotAvailable()
706
+ except OptionalDependencyNotAvailable:
707
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
708
+ else:
709
+ from .pipelines import (
710
+ FlaxStableDiffusionControlNetPipeline,
711
+ FlaxStableDiffusionImg2ImgPipeline,
712
+ FlaxStableDiffusionInpaintPipeline,
713
+ FlaxStableDiffusionPipeline,
714
+ FlaxStableDiffusionXLPipeline,
715
+ )
716
+
717
+ try:
718
+ if not (is_note_seq_available()):
719
+ raise OptionalDependencyNotAvailable()
720
+ except OptionalDependencyNotAvailable:
721
+ from .utils.dummy_note_seq_objects import * # noqa F403
722
+ else:
723
+ from .pipelines import MidiProcessor
724
+
725
+ else:
726
+ import sys
727
+
728
+ sys.modules[__name__] = _LazyModule(
729
+ __name__,
730
+ globals()["__file__"],
731
+ _import_structure,
732
+ module_spec=__spec__,
733
+ extra_objects={"__version__": __version__},
734
+ )
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+ from .fp16_safetensors import FP16SafetensorsCommand
20
+
21
+
22
+ def main():
23
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
24
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
25
+
26
+ # Register commands
27
+ EnvironmentCommand.register_subcommand(commands_parser)
28
+ FP16SafetensorsCommand.register_subcommand(commands_parser)
29
+
30
+ # Let's go
31
+ args = parser.parse_args()
32
+
33
+ if not hasattr(args, "func"):
34
+ parser.print_help()
35
+ exit(1)
36
+
37
+ # Run
38
+ service = args.func(args)
39
+ service.run()
40
+
41
+
42
+ if __name__ == "__main__":
43
+ main()
diffusers/commands/env.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ from argparse import ArgumentParser
17
+
18
+ import huggingface_hub
19
+
20
+ from .. import __version__ as version
21
+ from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22
+ from . import BaseDiffusersCLICommand
23
+
24
+
25
+ def info_command_factory(_):
26
+ return EnvironmentCommand()
27
+
28
+
29
+ class EnvironmentCommand(BaseDiffusersCLICommand):
30
+ @staticmethod
31
+ def register_subcommand(parser: ArgumentParser):
32
+ download_parser = parser.add_parser("env")
33
+ download_parser.set_defaults(func=info_command_factory)
34
+
35
+ def run(self):
36
+ hub_version = huggingface_hub.__version__
37
+
38
+ pt_version = "not installed"
39
+ pt_cuda_available = "NA"
40
+ if is_torch_available():
41
+ import torch
42
+
43
+ pt_version = torch.__version__
44
+ pt_cuda_available = torch.cuda.is_available()
45
+
46
+ transformers_version = "not installed"
47
+ if is_transformers_available():
48
+ import transformers
49
+
50
+ transformers_version = transformers.__version__
51
+
52
+ accelerate_version = "not installed"
53
+ if is_accelerate_available():
54
+ import accelerate
55
+
56
+ accelerate_version = accelerate.__version__
57
+
58
+ xformers_version = "not installed"
59
+ if is_xformers_available():
60
+ import xformers
61
+
62
+ xformers_version = xformers.__version__
63
+
64
+ info = {
65
+ "`diffusers` version": version,
66
+ "Platform": platform.platform(),
67
+ "Python version": platform.python_version(),
68
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
69
+ "Huggingface_hub version": hub_version,
70
+ "Transformers version": transformers_version,
71
+ "Accelerate version": accelerate_version,
72
+ "xFormers version": xformers_version,
73
+ "Using GPU in script?": "<fill in>",
74
+ "Using distributed or parallel set-up in script?": "<fill in>",
75
+ }
76
+
77
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
78
+ print(self.format_dict(info))
79
+
80
+ return info
81
+
82
+ @staticmethod
83
+ def format_dict(d):
84
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ from argparse import ArgumentParser, Namespace
23
+ from importlib import import_module
24
+
25
+ import huggingface_hub
26
+ import torch
27
+ from huggingface_hub import hf_hub_download
28
+ from packaging import version
29
+
30
+ from ..utils import logging
31
+ from . import BaseDiffusersCLICommand
32
+
33
+
34
+ def conversion_command_factory(args: Namespace):
35
+ return FP16SafetensorsCommand(
36
+ args.ckpt_id,
37
+ args.fp16,
38
+ args.use_safetensors,
39
+ args.use_auth_token,
40
+ )
41
+
42
+
43
+ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
44
+ @staticmethod
45
+ def register_subcommand(parser: ArgumentParser):
46
+ conversion_parser = parser.add_parser("fp16_safetensors")
47
+ conversion_parser.add_argument(
48
+ "--ckpt_id",
49
+ type=str,
50
+ help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
51
+ )
52
+ conversion_parser.add_argument(
53
+ "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
54
+ )
55
+ conversion_parser.add_argument(
56
+ "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
57
+ )
58
+ conversion_parser.add_argument(
59
+ "--use_auth_token",
60
+ action="store_true",
61
+ help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
62
+ )
63
+ conversion_parser.set_defaults(func=conversion_command_factory)
64
+
65
+ def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
66
+ self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
67
+ self.ckpt_id = ckpt_id
68
+ self.local_ckpt_dir = f"/tmp/{ckpt_id}"
69
+ self.fp16 = fp16
70
+
71
+ self.use_safetensors = use_safetensors
72
+
73
+ if not self.use_safetensors and not self.fp16:
74
+ raise NotImplementedError(
75
+ "When `use_safetensors` and `fp16` both are False, then this command is of no use."
76
+ )
77
+
78
+ self.use_auth_token = use_auth_token
79
+
80
+ def run(self):
81
+ if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
82
+ raise ImportError(
83
+ "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
84
+ " installation."
85
+ )
86
+ else:
87
+ from huggingface_hub import create_commit
88
+ from huggingface_hub._commit_api import CommitOperationAdd
89
+
90
+ model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
91
+ with open(model_index, "r") as f:
92
+ pipeline_class_name = json.load(f)["_class_name"]
93
+ pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
94
+ self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
95
+
96
+ # Load the appropriate pipeline. We could have use `DiffusionPipeline`
97
+ # here, but just to avoid any rough edge cases.
98
+ pipeline = pipeline_class.from_pretrained(
99
+ self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
100
+ )
101
+ pipeline.save_pretrained(
102
+ self.local_ckpt_dir,
103
+ safe_serialization=True if self.use_safetensors else False,
104
+ variant="fp16" if self.fp16 else None,
105
+ )
106
+ self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
107
+
108
+ # Fetch all the paths.
109
+ if self.fp16:
110
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
111
+ elif self.use_safetensors:
112
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
113
+
114
+ # Prepare for the PR.
115
+ commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
116
+ operations = []
117
+ for path in modified_paths:
118
+ operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
119
+
120
+ # Open the PR.
121
+ commit_description = (
122
+ "Variables converted by the [`diffusers`' `fp16_safetensors`"
123
+ " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
124
+ )
125
+ hub_pr_url = create_commit(
126
+ repo_id=self.ckpt_id,
127
+ operations=operations,
128
+ commit_message=commit_message,
129
+ commit_description=commit_description,
130
+ repo_type="model",
131
+ create_pr=True,
132
+ ).pr_url
133
+ self.logger.info(f"PR created here: {hub_pr_url}.")
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import dataclasses
18
+ import functools
19
+ import importlib
20
+ import inspect
21
+ import json
22
+ import os
23
+ import re
24
+ from collections import OrderedDict
25
+ from pathlib import PosixPath
26
+ from typing import Any, Dict, Tuple, Union
27
+
28
+ import numpy as np
29
+ from huggingface_hub import create_repo, hf_hub_download
30
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
+ from requests import HTTPError
32
+
33
+ from . import __version__
34
+ from .utils import (
35
+ DIFFUSERS_CACHE,
36
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37
+ DummyObject,
38
+ deprecate,
39
+ extract_commit_hash,
40
+ http_user_agent,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
48
+
49
+
50
+ class FrozenDict(OrderedDict):
51
+ def __init__(self, *args, **kwargs):
52
+ super().__init__(*args, **kwargs)
53
+
54
+ for key, value in self.items():
55
+ setattr(self, key, value)
56
+
57
+ self.__frozen = True
58
+
59
+ def __delitem__(self, *args, **kwargs):
60
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
61
+
62
+ def setdefault(self, *args, **kwargs):
63
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
64
+
65
+ def pop(self, *args, **kwargs):
66
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
67
+
68
+ def update(self, *args, **kwargs):
69
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
70
+
71
+ def __setattr__(self, name, value):
72
+ if hasattr(self, "__frozen") and self.__frozen:
73
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
74
+ super().__setattr__(name, value)
75
+
76
+ def __setitem__(self, name, value):
77
+ if hasattr(self, "__frozen") and self.__frozen:
78
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
+ super().__setitem__(name, value)
80
+
81
+
82
+ class ConfigMixin:
83
+ r"""
84
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86
+ saving classes that inherit from [`ConfigMixin`].
87
+
88
+ Class attributes:
89
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
90
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
91
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
92
+ overridden by subclass).
93
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
94
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
95
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
96
+ subclass).
97
+ """
98
+ config_name = None
99
+ ignore_for_config = []
100
+ has_compatibles = False
101
+
102
+ _deprecated_kwargs = []
103
+
104
+ def register_to_config(self, **kwargs):
105
+ if self.config_name is None:
106
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
107
+ # Special case for `kwargs` used in deprecation warning added to schedulers
108
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
109
+ # or solve in a more general way.
110
+ kwargs.pop("kwargs", None)
111
+
112
+ if not hasattr(self, "_internal_dict"):
113
+ internal_dict = kwargs
114
+ else:
115
+ previous_dict = dict(self._internal_dict)
116
+ internal_dict = {**self._internal_dict, **kwargs}
117
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
118
+
119
+ self._internal_dict = FrozenDict(internal_dict)
120
+
121
+ def __getattr__(self, name: str) -> Any:
122
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124
+
125
+ Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
126
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127
+ """
128
+
129
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130
+ is_attribute = name in self.__dict__
131
+
132
+ if is_in_config and not is_attribute:
133
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135
+ return self._internal_dict[name]
136
+
137
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138
+
139
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
140
+ """
141
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
142
+ [`~ConfigMixin.from_config`] class method.
143
+
144
+ Args:
145
+ save_directory (`str` or `os.PathLike`):
146
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
147
+ push_to_hub (`bool`, *optional*, defaults to `False`):
148
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
149
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
150
+ namespace).
151
+ kwargs (`Dict[str, Any]`, *optional*):
152
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
153
+ """
154
+ if os.path.isfile(save_directory):
155
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
156
+
157
+ os.makedirs(save_directory, exist_ok=True)
158
+
159
+ # If we save using the predefined names, we can load using `from_config`
160
+ output_config_file = os.path.join(save_directory, self.config_name)
161
+
162
+ self.to_json_file(output_config_file)
163
+ logger.info(f"Configuration saved in {output_config_file}")
164
+
165
+ if push_to_hub:
166
+ commit_message = kwargs.pop("commit_message", None)
167
+ private = kwargs.pop("private", False)
168
+ create_pr = kwargs.pop("create_pr", False)
169
+ token = kwargs.pop("token", None)
170
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
171
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
172
+
173
+ self._upload_folder(
174
+ save_directory,
175
+ repo_id,
176
+ token=token,
177
+ commit_message=commit_message,
178
+ create_pr=create_pr,
179
+ )
180
+
181
+ @classmethod
182
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
183
+ r"""
184
+ Instantiate a Python class from a config dictionary.
185
+
186
+ Parameters:
187
+ config (`Dict[str, Any]`):
188
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
189
+ files of compatible classes.
190
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
191
+ Whether kwargs that are not consumed by the Python class should be returned or not.
192
+ kwargs (remaining dictionary of keyword arguments, *optional*):
193
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
194
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
195
+ overwrite the same named arguments in `config`.
196
+
197
+ Returns:
198
+ [`ModelMixin`] or [`SchedulerMixin`]:
199
+ A model or scheduler object instantiated from a config dictionary.
200
+
201
+ Examples:
202
+
203
+ ```python
204
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
205
+
206
+ >>> # Download scheduler from huggingface.co and cache.
207
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
208
+
209
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
210
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
211
+
212
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
213
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
214
+ ```
215
+ """
216
+ # <===== TO BE REMOVED WITH DEPRECATION
217
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
218
+ if "pretrained_model_name_or_path" in kwargs:
219
+ config = kwargs.pop("pretrained_model_name_or_path")
220
+
221
+ if config is None:
222
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
223
+ # ======>
224
+
225
+ if not isinstance(config, dict):
226
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
227
+ if "Scheduler" in cls.__name__:
228
+ deprecation_message += (
229
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
230
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
231
+ " be removed in v1.0.0."
232
+ )
233
+ elif "Model" in cls.__name__:
234
+ deprecation_message += (
235
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
236
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
237
+ " instead. This functionality will be removed in v1.0.0."
238
+ )
239
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
240
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
241
+
242
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
243
+
244
+ # Allow dtype to be specified on initialization
245
+ if "dtype" in unused_kwargs:
246
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
247
+
248
+ # add possible deprecated kwargs
249
+ for deprecated_kwarg in cls._deprecated_kwargs:
250
+ if deprecated_kwarg in unused_kwargs:
251
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
252
+
253
+ # Return model and optionally state and/or unused_kwargs
254
+ model = cls(**init_dict)
255
+
256
+ # make sure to also save config parameters that might be used for compatible classes
257
+ model.register_to_config(**hidden_dict)
258
+
259
+ # add hidden kwargs of compatible classes to unused_kwargs
260
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
261
+
262
+ if return_unused_kwargs:
263
+ return (model, unused_kwargs)
264
+ else:
265
+ return model
266
+
267
+ @classmethod
268
+ def get_config_dict(cls, *args, **kwargs):
269
+ deprecation_message = (
270
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
271
+ " removed in version v1.0.0"
272
+ )
273
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
274
+ return cls.load_config(*args, **kwargs)
275
+
276
+ @classmethod
277
+ def load_config(
278
+ cls,
279
+ pretrained_model_name_or_path: Union[str, os.PathLike],
280
+ return_unused_kwargs=False,
281
+ return_commit_hash=False,
282
+ **kwargs,
283
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
284
+ r"""
285
+ Load a model or scheduler configuration.
286
+
287
+ Parameters:
288
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
289
+ Can be either:
290
+
291
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
292
+ the Hub.
293
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
294
+ [`~ConfigMixin.save_config`].
295
+
296
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
297
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
298
+ is not used.
299
+ force_download (`bool`, *optional*, defaults to `False`):
300
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
301
+ cached versions if they exist.
302
+ resume_download (`bool`, *optional*, defaults to `False`):
303
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
304
+ incompletely downloaded files are deleted.
305
+ proxies (`Dict[str, str]`, *optional*):
306
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
307
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
308
+ output_loading_info(`bool`, *optional*, defaults to `False`):
309
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
310
+ local_files_only (`bool`, *optional*, defaults to `False`):
311
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
312
+ won't be downloaded from the Hub.
313
+ use_auth_token (`str` or *bool*, *optional*):
314
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
315
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
316
+ revision (`str`, *optional*, defaults to `"main"`):
317
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
318
+ allowed by Git.
319
+ subfolder (`str`, *optional*, defaults to `""`):
320
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
321
+ return_unused_kwargs (`bool`, *optional*, defaults to `False):
322
+ Whether unused keyword arguments of the config are returned.
323
+ return_commit_hash (`bool`, *optional*, defaults to `False):
324
+ Whether the `commit_hash` of the loaded configuration are returned.
325
+
326
+ Returns:
327
+ `dict`:
328
+ A dictionary of all the parameters stored in a JSON configuration file.
329
+
330
+ """
331
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
332
+ force_download = kwargs.pop("force_download", False)
333
+ resume_download = kwargs.pop("resume_download", False)
334
+ proxies = kwargs.pop("proxies", None)
335
+ use_auth_token = kwargs.pop("use_auth_token", None)
336
+ local_files_only = kwargs.pop("local_files_only", False)
337
+ revision = kwargs.pop("revision", None)
338
+ _ = kwargs.pop("mirror", None)
339
+ subfolder = kwargs.pop("subfolder", None)
340
+ user_agent = kwargs.pop("user_agent", {})
341
+
342
+ user_agent = {**user_agent, "file_type": "config"}
343
+ user_agent = http_user_agent(user_agent)
344
+
345
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
346
+
347
+ if cls.config_name is None:
348
+ raise ValueError(
349
+ "`self.config_name` is not defined. Note that one should not load a config from "
350
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
351
+ )
352
+
353
+ if os.path.isfile(pretrained_model_name_or_path):
354
+ config_file = pretrained_model_name_or_path
355
+ elif os.path.isdir(pretrained_model_name_or_path):
356
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
357
+ # Load from a PyTorch checkpoint
358
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
359
+ elif subfolder is not None and os.path.isfile(
360
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
361
+ ):
362
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
363
+ else:
364
+ raise EnvironmentError(
365
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
366
+ )
367
+ else:
368
+ try:
369
+ # Load from URL or cache if already cached
370
+ config_file = hf_hub_download(
371
+ pretrained_model_name_or_path,
372
+ filename=cls.config_name,
373
+ cache_dir=cache_dir,
374
+ force_download=force_download,
375
+ proxies=proxies,
376
+ resume_download=resume_download,
377
+ local_files_only=local_files_only,
378
+ use_auth_token=use_auth_token,
379
+ user_agent=user_agent,
380
+ subfolder=subfolder,
381
+ revision=revision,
382
+ )
383
+ except RepositoryNotFoundError:
384
+ raise EnvironmentError(
385
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
386
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
387
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
388
+ " login`."
389
+ )
390
+ except RevisionNotFoundError:
391
+ raise EnvironmentError(
392
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
393
+ " this model name. Check the model page at"
394
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
395
+ )
396
+ except EntryNotFoundError:
397
+ raise EnvironmentError(
398
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
399
+ )
400
+ except HTTPError as err:
401
+ raise EnvironmentError(
402
+ "There was a specific connection error when trying to load"
403
+ f" {pretrained_model_name_or_path}:\n{err}"
404
+ )
405
+ except ValueError:
406
+ raise EnvironmentError(
407
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
408
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
409
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
410
+ " run the library in offline mode at"
411
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
412
+ )
413
+ except EnvironmentError:
414
+ raise EnvironmentError(
415
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
416
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
417
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
418
+ f"containing a {cls.config_name} file"
419
+ )
420
+
421
+ try:
422
+ # Load config dict
423
+ config_dict = cls._dict_from_json_file(config_file)
424
+
425
+ commit_hash = extract_commit_hash(config_file)
426
+ except (json.JSONDecodeError, UnicodeDecodeError):
427
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
428
+
429
+ if not (return_unused_kwargs or return_commit_hash):
430
+ return config_dict
431
+
432
+ outputs = (config_dict,)
433
+
434
+ if return_unused_kwargs:
435
+ outputs += (kwargs,)
436
+
437
+ if return_commit_hash:
438
+ outputs += (commit_hash,)
439
+
440
+ return outputs
441
+
442
+ @staticmethod
443
+ def _get_init_keys(cls):
444
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
445
+
446
+ @classmethod
447
+ def extract_init_dict(cls, config_dict, **kwargs):
448
+ # Skip keys that were not present in the original config, so default __init__ values were used
449
+ used_defaults = config_dict.get("_use_default_values", [])
450
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
451
+
452
+ # 0. Copy origin config dict
453
+ original_dict = dict(config_dict.items())
454
+
455
+ # 1. Retrieve expected config attributes from __init__ signature
456
+ expected_keys = cls._get_init_keys(cls)
457
+ expected_keys.remove("self")
458
+ # remove general kwargs if present in dict
459
+ if "kwargs" in expected_keys:
460
+ expected_keys.remove("kwargs")
461
+ # remove flax internal keys
462
+ if hasattr(cls, "_flax_internal_args"):
463
+ for arg in cls._flax_internal_args:
464
+ expected_keys.remove(arg)
465
+
466
+ # 2. Remove attributes that cannot be expected from expected config attributes
467
+ # remove keys to be ignored
468
+ if len(cls.ignore_for_config) > 0:
469
+ expected_keys = expected_keys - set(cls.ignore_for_config)
470
+
471
+ # load diffusers library to import compatible and original scheduler
472
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
473
+
474
+ if cls.has_compatibles:
475
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
476
+ else:
477
+ compatible_classes = []
478
+
479
+ expected_keys_comp_cls = set()
480
+ for c in compatible_classes:
481
+ expected_keys_c = cls._get_init_keys(c)
482
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
483
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
484
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
485
+
486
+ # remove attributes from orig class that cannot be expected
487
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
488
+ if (
489
+ isinstance(orig_cls_name, str)
490
+ and orig_cls_name != cls.__name__
491
+ and hasattr(diffusers_library, orig_cls_name)
492
+ ):
493
+ orig_cls = getattr(diffusers_library, orig_cls_name)
494
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
495
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
496
+ elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
497
+ raise ValueError(
498
+ "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
499
+ )
500
+
501
+ # remove private attributes
502
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
503
+
504
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
505
+ init_dict = {}
506
+ for key in expected_keys:
507
+ # if config param is passed to kwarg and is present in config dict
508
+ # it should overwrite existing config dict key
509
+ if key in kwargs and key in config_dict:
510
+ config_dict[key] = kwargs.pop(key)
511
+
512
+ if key in kwargs:
513
+ # overwrite key
514
+ init_dict[key] = kwargs.pop(key)
515
+ elif key in config_dict:
516
+ # use value from config dict
517
+ init_dict[key] = config_dict.pop(key)
518
+
519
+ # 4. Give nice warning if unexpected values have been passed
520
+ if len(config_dict) > 0:
521
+ logger.warning(
522
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
523
+ "but are not expected and will be ignored. Please verify your "
524
+ f"{cls.config_name} configuration file."
525
+ )
526
+
527
+ # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
528
+ passed_keys = set(init_dict.keys())
529
+ if len(expected_keys - passed_keys) > 0:
530
+ logger.info(
531
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
532
+ )
533
+
534
+ # 6. Define unused keyword arguments
535
+ unused_kwargs = {**config_dict, **kwargs}
536
+
537
+ # 7. Define "hidden" config parameters that were saved for compatible classes
538
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
539
+
540
+ return init_dict, unused_kwargs, hidden_config_dict
541
+
542
+ @classmethod
543
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
544
+ with open(json_file, "r", encoding="utf-8") as reader:
545
+ text = reader.read()
546
+ return json.loads(text)
547
+
548
+ def __repr__(self):
549
+ return f"{self.__class__.__name__} {self.to_json_string()}"
550
+
551
+ @property
552
+ def config(self) -> Dict[str, Any]:
553
+ """
554
+ Returns the config of the class as a frozen dictionary
555
+
556
+ Returns:
557
+ `Dict[str, Any]`: Config of the class.
558
+ """
559
+ return self._internal_dict
560
+
561
+ def to_json_string(self) -> str:
562
+ """
563
+ Serializes the configuration instance to a JSON string.
564
+
565
+ Returns:
566
+ `str`:
567
+ String containing all the attributes that make up the configuration instance in JSON format.
568
+ """
569
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
570
+ config_dict["_class_name"] = self.__class__.__name__
571
+ config_dict["_diffusers_version"] = __version__
572
+
573
+ def to_json_saveable(value):
574
+ if isinstance(value, np.ndarray):
575
+ value = value.tolist()
576
+ elif isinstance(value, PosixPath):
577
+ value = str(value)
578
+ return value
579
+
580
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
581
+ # Don't save "_ignore_files" or "_use_default_values"
582
+ config_dict.pop("_ignore_files", None)
583
+ config_dict.pop("_use_default_values", None)
584
+
585
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
586
+
587
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
588
+ """
589
+ Save the configuration instance's parameters to a JSON file.
590
+
591
+ Args:
592
+ json_file_path (`str` or `os.PathLike`):
593
+ Path to the JSON file to save a configuration instance's parameters.
594
+ """
595
+ with open(json_file_path, "w", encoding="utf-8") as writer:
596
+ writer.write(self.to_json_string())
597
+
598
+
599
+ def register_to_config(init):
600
+ r"""
601
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
602
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
603
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
604
+
605
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
606
+ """
607
+
608
+ @functools.wraps(init)
609
+ def inner_init(self, *args, **kwargs):
610
+ # Ignore private kwargs in the init.
611
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
612
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
613
+ if not isinstance(self, ConfigMixin):
614
+ raise RuntimeError(
615
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
616
+ "not inherit from `ConfigMixin`."
617
+ )
618
+
619
+ ignore = getattr(self, "ignore_for_config", [])
620
+ # Get positional arguments aligned with kwargs
621
+ new_kwargs = {}
622
+ signature = inspect.signature(init)
623
+ parameters = {
624
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
625
+ }
626
+ for arg, name in zip(args, parameters.keys()):
627
+ new_kwargs[name] = arg
628
+
629
+ # Then add all kwargs
630
+ new_kwargs.update(
631
+ {
632
+ k: init_kwargs.get(k, default)
633
+ for k, default in parameters.items()
634
+ if k not in ignore and k not in new_kwargs
635
+ }
636
+ )
637
+
638
+ # Take note of the parameters that were not present in the loaded config
639
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
640
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
641
+
642
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
643
+ getattr(self, "register_to_config")(**new_kwargs)
644
+ init(self, *args, **init_kwargs)
645
+
646
+ return inner_init
647
+
648
+
649
+ def flax_register_to_config(cls):
650
+ original_init = cls.__init__
651
+
652
+ @functools.wraps(original_init)
653
+ def init(self, *args, **kwargs):
654
+ if not isinstance(self, ConfigMixin):
655
+ raise RuntimeError(
656
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
657
+ "not inherit from `ConfigMixin`."
658
+ )
659
+
660
+ # Ignore private kwargs in the init. Retrieve all passed attributes
661
+ init_kwargs = dict(kwargs.items())
662
+
663
+ # Retrieve default values
664
+ fields = dataclasses.fields(self)
665
+ default_kwargs = {}
666
+ for field in fields:
667
+ # ignore flax specific attributes
668
+ if field.name in self._flax_internal_args:
669
+ continue
670
+ if type(field.default) == dataclasses._MISSING_TYPE:
671
+ default_kwargs[field.name] = None
672
+ else:
673
+ default_kwargs[field.name] = getattr(self, field.name)
674
+
675
+ # Make sure init_kwargs override default kwargs
676
+ new_kwargs = {**default_kwargs, **init_kwargs}
677
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
678
+ if "dtype" in new_kwargs:
679
+ new_kwargs.pop("dtype")
680
+
681
+ # Get positional arguments aligned with kwargs
682
+ for i, arg in enumerate(args):
683
+ name = fields[i].name
684
+ new_kwargs[name] = arg
685
+
686
+ # Take note of the parameters that were not present in the loaded config
687
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
688
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
689
+
690
+ getattr(self, "register_to_config")(**new_kwargs)
691
+ original_init(self, *args, **kwargs)
692
+
693
+ cls.__init__ = init
694
+ return cls
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
+ # define which module versions we always want to check at run time
21
+ # (usually the ones defined in `install_requires` in setup.py)
22
+ #
23
+ # order specific notes:
24
+ # - tqdm must be checked before tokenizers
25
+
26
+ pkgs_to_check_at_runtime = "python requests filelock numpy".split()
27
+ for pkg in pkgs_to_check_at_runtime:
28
+ if pkg in deps:
29
+ require_version_core(deps[pkg])
30
+ else:
31
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
32
+
33
+
34
+ def dep_version_check(pkg, hint=None):
35
+ require_version(deps[pkg], hint)
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update``
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.11.0",
7
+ "compel": "compel==0.1.8",
8
+ "black": "black~=23.1",
9
+ "datasets": "datasets",
10
+ "filelock": "filelock",
11
+ "flax": "flax>=0.4.1",
12
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
13
+ "huggingface-hub": "huggingface-hub>=0.13.2",
14
+ "requests-mock": "requests-mock==1.10.0",
15
+ "importlib_metadata": "importlib_metadata",
16
+ "invisible-watermark": "invisible-watermark>=0.2.0",
17
+ "isort": "isort>=5.5.4",
18
+ "jax": "jax>=0.4.1",
19
+ "jaxlib": "jaxlib>=0.4.1",
20
+ "Jinja2": "Jinja2",
21
+ "k-diffusion": "k-diffusion>=0.0.12",
22
+ "torchsde": "torchsde",
23
+ "note_seq": "note_seq",
24
+ "librosa": "librosa",
25
+ "numpy": "numpy",
26
+ "omegaconf": "omegaconf",
27
+ "parameterized": "parameterized",
28
+ "peft": "peft<=0.6.2",
29
+ "protobuf": "protobuf>=3.20.3,<4",
30
+ "pytest": "pytest",
31
+ "pytest-timeout": "pytest-timeout",
32
+ "pytest-xdist": "pytest-xdist",
33
+ "python": "python>=3.8.0",
34
+ "ruff": "ruff==0.0.280",
35
+ "safetensors": "safetensors>=0.3.1",
36
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
37
+ "scipy": "scipy",
38
+ "onnx": "onnx",
39
+ "regex": "regex!=2019.12.17",
40
+ "requests": "requests",
41
+ "tensorboard": "tensorboard",
42
+ "torch": "torch>=1.4",
43
+ "torchvision": "torchvision",
44
+ "transformers": "transformers>=4.25.1",
45
+ "urllib3": "urllib3<=2.0.0",
46
+ }
diffusers/experimental/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # 🧨 Diffusers Experimental
2
+
3
+ We are adding experimental code to support novel applications and usages of the Diffusers library.
4
+ Currently, the following experiments are supported:
5
+ * Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .rl import ValueGuidedRLPipeline
diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils.dummy_pt_objects import DDPMScheduler
22
+ from ...utils.torch_utils import randn_tensor
23
+
24
+
25
+ class ValueGuidedRLPipeline(DiffusionPipeline):
26
+ r"""
27
+ Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.
28
+
29
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
30
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
31
+
32
+ Parameters:
33
+ value_function ([`UNet1DModel`]):
34
+ A specialized UNet for fine-tuning trajectories base on reward.
35
+ unet ([`UNet1DModel`]):
36
+ UNet architecture to denoise the encoded trajectories.
37
+ scheduler ([`SchedulerMixin`]):
38
+ A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
39
+ application is [`DDPMScheduler`].
40
+ env ():
41
+ An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ value_function: UNet1DModel,
47
+ unet: UNet1DModel,
48
+ scheduler: DDPMScheduler,
49
+ env,
50
+ ):
51
+ super().__init__()
52
+ self.value_function = value_function
53
+ self.unet = unet
54
+ self.scheduler = scheduler
55
+ self.env = env
56
+ self.data = env.get_dataset()
57
+ self.means = {}
58
+ for key in self.data.keys():
59
+ try:
60
+ self.means[key] = self.data[key].mean()
61
+ except: # noqa: E722
62
+ pass
63
+ self.stds = {}
64
+ for key in self.data.keys():
65
+ try:
66
+ self.stds[key] = self.data[key].std()
67
+ except: # noqa: E722
68
+ pass
69
+ self.state_dim = env.observation_space.shape[0]
70
+ self.action_dim = env.action_space.shape[0]
71
+
72
+ def normalize(self, x_in, key):
73
+ return (x_in - self.means[key]) / self.stds[key]
74
+
75
+ def de_normalize(self, x_in, key):
76
+ return x_in * self.stds[key] + self.means[key]
77
+
78
+ def to_torch(self, x_in):
79
+ if isinstance(x_in, dict):
80
+ return {k: self.to_torch(v) for k, v in x_in.items()}
81
+ elif torch.is_tensor(x_in):
82
+ return x_in.to(self.unet.device)
83
+ return torch.tensor(x_in, device=self.unet.device)
84
+
85
+ def reset_x0(self, x_in, cond, act_dim):
86
+ for key, val in cond.items():
87
+ x_in[:, key, act_dim:] = val.clone()
88
+ return x_in
89
+
90
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
91
+ batch_size = x.shape[0]
92
+ y = None
93
+ for i in tqdm.tqdm(self.scheduler.timesteps):
94
+ # create batch of timesteps to pass into model
95
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
96
+ for _ in range(n_guide_steps):
97
+ with torch.enable_grad():
98
+ x.requires_grad_()
99
+
100
+ # permute to match dimension for pre-trained models
101
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
102
+ grad = torch.autograd.grad([y.sum()], [x])[0]
103
+
104
+ posterior_variance = self.scheduler._get_variance(i)
105
+ model_std = torch.exp(0.5 * posterior_variance)
106
+ grad = model_std * grad
107
+
108
+ grad[timesteps < 2] = 0
109
+ x = x.detach()
110
+ x = x + scale * grad
111
+ x = self.reset_x0(x, conditions, self.action_dim)
112
+
113
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
114
+
115
+ # TODO: verify deprecation of this kwarg
116
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
117
+
118
+ # apply conditions to the trajectory (set the initial state)
119
+ x = self.reset_x0(x, conditions, self.action_dim)
120
+ x = self.to_torch(x)
121
+ return x, y
122
+
123
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
124
+ # normalize the observations and create batch dimension
125
+ obs = self.normalize(obs, "observations")
126
+ obs = obs[None].repeat(batch_size, axis=0)
127
+
128
+ conditions = {0: self.to_torch(obs)}
129
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
130
+
131
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
132
+ x1 = randn_tensor(shape, device=self.unet.device)
133
+ x = self.reset_x0(x1, conditions, self.action_dim)
134
+ x = self.to_torch(x)
135
+
136
+ # run the diffusion process
137
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
138
+
139
+ # sort output trajectories by value
140
+ sorted_idx = y.argsort(0, descending=True).squeeze()
141
+ sorted_values = x[sorted_idx]
142
+ actions = sorted_values[:, :, : self.action_dim]
143
+ actions = actions.detach().cpu().numpy()
144
+ denorm_actions = self.de_normalize(actions, key="actions")
145
+
146
+ # select the action with the highest value
147
+ if y is not None:
148
+ selected_index = 0
149
+ else:
150
+ # if we didn't run value guiding, select a random action
151
+ selected_index = np.random.randint(0, batch_size)
152
+
153
+ denorm_actions = denorm_actions[selected_index, 0]
154
+ return denorm_actions
diffusers/image_processor.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from PIL import Image
22
+
23
+ from .configuration_utils import ConfigMixin, register_to_config
24
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
+
26
+
27
+ PipelineImageInput = Union[
28
+ PIL.Image.Image,
29
+ np.ndarray,
30
+ torch.FloatTensor,
31
+ List[PIL.Image.Image],
32
+ List[np.ndarray],
33
+ List[torch.FloatTensor],
34
+ ]
35
+
36
+
37
+ class VaeImageProcessor(ConfigMixin):
38
+ """
39
+ Image processor for VAE.
40
+
41
+ Args:
42
+ do_resize (`bool`, *optional*, defaults to `True`):
43
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
44
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
45
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
46
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
47
+ resample (`str`, *optional*, defaults to `lanczos`):
48
+ Resampling filter to use when resizing the image.
49
+ do_normalize (`bool`, *optional*, defaults to `True`):
50
+ Whether to normalize the image to [-1,1].
51
+ do_binarize (`bool`, *optional*, defaults to `False`):
52
+ Whether to binarize the image to 0/1.
53
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
54
+ Whether to convert the images to RGB format.
55
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
56
+ Whether to convert the images to grayscale format.
57
+ """
58
+
59
+ config_name = CONFIG_NAME
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ do_resize: bool = True,
65
+ vae_scale_factor: int = 8,
66
+ resample: str = "lanczos",
67
+ do_normalize: bool = True,
68
+ do_binarize: bool = False,
69
+ do_convert_rgb: bool = False,
70
+ do_convert_grayscale: bool = False,
71
+ ):
72
+ super().__init__()
73
+ if do_convert_rgb and do_convert_grayscale:
74
+ raise ValueError(
75
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
76
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
77
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
78
+ )
79
+ self.config.do_convert_rgb = False
80
+
81
+ @staticmethod
82
+ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
83
+ """
84
+ Convert a numpy image or a batch of images to a PIL image.
85
+ """
86
+ if images.ndim == 3:
87
+ images = images[None, ...]
88
+ images = (images * 255).round().astype("uint8")
89
+ if images.shape[-1] == 1:
90
+ # special case for grayscale (single channel) images
91
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
92
+ else:
93
+ pil_images = [Image.fromarray(image) for image in images]
94
+
95
+ return pil_images
96
+
97
+ @staticmethod
98
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
99
+ """
100
+ Convert a PIL image or a list of PIL images to NumPy arrays.
101
+ """
102
+ if not isinstance(images, list):
103
+ images = [images]
104
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
105
+ images = np.stack(images, axis=0)
106
+
107
+ return images
108
+
109
+ @staticmethod
110
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
111
+ """
112
+ Convert a NumPy image to a PyTorch tensor.
113
+ """
114
+ if images.ndim == 3:
115
+ images = images[..., None]
116
+
117
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
118
+ return images
119
+
120
+ @staticmethod
121
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
122
+ """
123
+ Convert a PyTorch tensor to a NumPy image.
124
+ """
125
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
126
+ return images
127
+
128
+ @staticmethod
129
+ def normalize(images):
130
+ """
131
+ Normalize an image array to [-1,1].
132
+ """
133
+ return 2.0 * images - 1.0
134
+
135
+ @staticmethod
136
+ def denormalize(images):
137
+ """
138
+ Denormalize an image array to [0,1].
139
+ """
140
+ return (images / 2 + 0.5).clamp(0, 1)
141
+
142
+ @staticmethod
143
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
144
+ """
145
+ Converts a PIL image to RGB format.
146
+ """
147
+ image = image.convert("RGB")
148
+
149
+ return image
150
+
151
+ @staticmethod
152
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
153
+ """
154
+ Converts a PIL image to grayscale format.
155
+ """
156
+ image = image.convert("L")
157
+
158
+ return image
159
+
160
+ def get_default_height_width(
161
+ self,
162
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
163
+ height: Optional[int] = None,
164
+ width: Optional[int] = None,
165
+ ):
166
+ """
167
+ This function return the height and width that are downscaled to the next integer multiple of
168
+ `vae_scale_factor`.
169
+
170
+ Args:
171
+ image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
172
+ The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
173
+ shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
174
+ have shape `[batch, channel, height, width]`.
175
+ height (`int`, *optional*, defaults to `None`):
176
+ The height in preprocessed image. If `None`, will use the height of `image` input.
177
+ width (`int`, *optional*`, defaults to `None`):
178
+ The width in preprocessed. If `None`, will use the width of the `image` input.
179
+ """
180
+
181
+ if height is None:
182
+ if isinstance(image, PIL.Image.Image):
183
+ height = image.height
184
+ elif isinstance(image, torch.Tensor):
185
+ height = image.shape[2]
186
+ else:
187
+ height = image.shape[1]
188
+
189
+ if width is None:
190
+ if isinstance(image, PIL.Image.Image):
191
+ width = image.width
192
+ elif isinstance(image, torch.Tensor):
193
+ width = image.shape[3]
194
+ else:
195
+ width = image.shape[2]
196
+
197
+ width, height = (
198
+ x - x % self.config.vae_scale_factor for x in (width, height)
199
+ ) # resize to integer multiple of vae_scale_factor
200
+
201
+ return height, width
202
+
203
+ def resize(
204
+ self,
205
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
206
+ height: Optional[int] = None,
207
+ width: Optional[int] = None,
208
+ ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
209
+ """
210
+ Resize image.
211
+ """
212
+ if isinstance(image, PIL.Image.Image):
213
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
214
+ elif isinstance(image, torch.Tensor):
215
+ image = torch.nn.functional.interpolate(
216
+ image,
217
+ size=(height, width),
218
+ )
219
+ elif isinstance(image, np.ndarray):
220
+ image = self.numpy_to_pt(image)
221
+ image = torch.nn.functional.interpolate(
222
+ image,
223
+ size=(height, width),
224
+ )
225
+ image = self.pt_to_numpy(image)
226
+ return image
227
+
228
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
229
+ """
230
+ create a mask
231
+ """
232
+ image[image < 0.5] = 0
233
+ image[image >= 0.5] = 1
234
+ return image
235
+
236
+ def preprocess(
237
+ self,
238
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
239
+ height: Optional[int] = None,
240
+ width: Optional[int] = None,
241
+ ) -> torch.Tensor:
242
+ """
243
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
244
+ """
245
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
246
+
247
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
248
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
249
+ if isinstance(image, torch.Tensor):
250
+ # if image is a pytorch tensor could have 2 possible shapes:
251
+ # 1. batch x height x width: we should insert the channel dimension at position 1
252
+ # 2. channnel x height x width: we should insert batch dimension at position 0,
253
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
254
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
255
+ image = image.unsqueeze(1)
256
+ else:
257
+ # if it is a numpy array, it could have 2 possible shapes:
258
+ # 1. batch x height x width: insert channel dimension on last position
259
+ # 2. height x width x channel: insert batch dimension on first position
260
+ if image.shape[-1] == 1:
261
+ image = np.expand_dims(image, axis=0)
262
+ else:
263
+ image = np.expand_dims(image, axis=-1)
264
+
265
+ if isinstance(image, supported_formats):
266
+ image = [image]
267
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
268
+ raise ValueError(
269
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
270
+ )
271
+
272
+ if isinstance(image[0], PIL.Image.Image):
273
+ if self.config.do_convert_rgb:
274
+ image = [self.convert_to_rgb(i) for i in image]
275
+ elif self.config.do_convert_grayscale:
276
+ image = [self.convert_to_grayscale(i) for i in image]
277
+ if self.config.do_resize:
278
+ height, width = self.get_default_height_width(image[0], height, width)
279
+ image = [self.resize(i, height, width) for i in image]
280
+ image = self.pil_to_numpy(image) # to np
281
+ image = self.numpy_to_pt(image) # to pt
282
+
283
+ elif isinstance(image[0], np.ndarray):
284
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
285
+
286
+ image = self.numpy_to_pt(image)
287
+
288
+ height, width = self.get_default_height_width(image, height, width)
289
+ if self.config.do_resize:
290
+ image = self.resize(image, height, width)
291
+
292
+ elif isinstance(image[0], torch.Tensor):
293
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
294
+
295
+ if self.config.do_convert_grayscale and image.ndim == 3:
296
+ image = image.unsqueeze(1)
297
+
298
+ channel = image.shape[1]
299
+ # don't need any preprocess if the image is latents
300
+ if channel == 4:
301
+ return image
302
+
303
+ height, width = self.get_default_height_width(image, height, width)
304
+ if self.config.do_resize:
305
+ image = self.resize(image, height, width)
306
+
307
+ # expected range [0,1], normalize to [-1,1]
308
+ do_normalize = self.config.do_normalize
309
+ if image.min() < 0 and do_normalize:
310
+ warnings.warn(
311
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
312
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
313
+ FutureWarning,
314
+ )
315
+ do_normalize = False
316
+
317
+ if do_normalize:
318
+ image = self.normalize(image)
319
+
320
+ if self.config.do_binarize:
321
+ image = self.binarize(image)
322
+
323
+ return image
324
+
325
+ def postprocess(
326
+ self,
327
+ image: torch.FloatTensor,
328
+ output_type: str = "pil",
329
+ do_denormalize: Optional[List[bool]] = None,
330
+ ):
331
+ if not isinstance(image, torch.Tensor):
332
+ raise ValueError(
333
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
334
+ )
335
+ if output_type not in ["latent", "pt", "np", "pil"]:
336
+ deprecation_message = (
337
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
338
+ "`pil`, `np`, `pt`, `latent`"
339
+ )
340
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
341
+ output_type = "np"
342
+
343
+ if output_type == "latent":
344
+ return image
345
+
346
+ if do_denormalize is None:
347
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
348
+
349
+ image = torch.stack(
350
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
351
+ )
352
+
353
+ if output_type == "pt":
354
+ return image
355
+
356
+ image = self.pt_to_numpy(image)
357
+
358
+ if output_type == "np":
359
+ return image
360
+
361
+ if output_type == "pil":
362
+ return self.numpy_to_pil(image)
363
+
364
+
365
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
366
+ """
367
+ Image processor for VAE LDM3D.
368
+
369
+ Args:
370
+ do_resize (`bool`, *optional*, defaults to `True`):
371
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
372
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
373
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
374
+ resample (`str`, *optional*, defaults to `lanczos`):
375
+ Resampling filter to use when resizing the image.
376
+ do_normalize (`bool`, *optional*, defaults to `True`):
377
+ Whether to normalize the image to [-1,1].
378
+ """
379
+
380
+ config_name = CONFIG_NAME
381
+
382
+ @register_to_config
383
+ def __init__(
384
+ self,
385
+ do_resize: bool = True,
386
+ vae_scale_factor: int = 8,
387
+ resample: str = "lanczos",
388
+ do_normalize: bool = True,
389
+ ):
390
+ super().__init__()
391
+
392
+ @staticmethod
393
+ def numpy_to_pil(images):
394
+ """
395
+ Convert a NumPy image or a batch of images to a PIL image.
396
+ """
397
+ if images.ndim == 3:
398
+ images = images[None, ...]
399
+ images = (images * 255).round().astype("uint8")
400
+ if images.shape[-1] == 1:
401
+ # special case for grayscale (single channel) images
402
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
403
+ else:
404
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
405
+
406
+ return pil_images
407
+
408
+ @staticmethod
409
+ def rgblike_to_depthmap(image):
410
+ """
411
+ Args:
412
+ image: RGB-like depth image
413
+
414
+ Returns: depth map
415
+
416
+ """
417
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
418
+
419
+ def numpy_to_depth(self, images):
420
+ """
421
+ Convert a NumPy depth image or a batch of images to a PIL image.
422
+ """
423
+ if images.ndim == 3:
424
+ images = images[None, ...]
425
+ images_depth = images[:, :, :, 3:]
426
+ if images.shape[-1] == 6:
427
+ images_depth = (images_depth * 255).round().astype("uint8")
428
+ pil_images = [
429
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
430
+ ]
431
+ elif images.shape[-1] == 4:
432
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
433
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
434
+ else:
435
+ raise Exception("Not supported")
436
+
437
+ return pil_images
438
+
439
+ def postprocess(
440
+ self,
441
+ image: torch.FloatTensor,
442
+ output_type: str = "pil",
443
+ do_denormalize: Optional[List[bool]] = None,
444
+ ):
445
+ if not isinstance(image, torch.Tensor):
446
+ raise ValueError(
447
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
448
+ )
449
+ if output_type not in ["latent", "pt", "np", "pil"]:
450
+ deprecation_message = (
451
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
452
+ "`pil`, `np`, `pt`, `latent`"
453
+ )
454
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
455
+ output_type = "np"
456
+
457
+ if do_denormalize is None:
458
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
459
+
460
+ image = torch.stack(
461
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
462
+ )
463
+
464
+ image = self.pt_to_numpy(image)
465
+
466
+ if output_type == "np":
467
+ if image.shape[-1] == 6:
468
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
469
+ else:
470
+ image_depth = image[:, :, :, 3:]
471
+ return image[:, :, :, :3], image_depth
472
+
473
+ if output_type == "pil":
474
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
475
+ else:
476
+ raise Exception(f"This type {output_type} is not supported")
diffusers/loaders.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers/models/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Models
2
+
3
+ For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
diffusers/models/__init__.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
18
+
19
+
20
+ _import_structure = {}
21
+
22
+ if is_torch_available():
23
+ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
24
+ _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
25
+ _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
26
+ _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
27
+ _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
28
+ _import_structure["controlnet"] = ["ControlNetModel"]
29
+ _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
30
+ _import_structure["modeling_utils"] = ["ModelMixin"]
31
+ _import_structure["prior_transformer"] = ["PriorTransformer"]
32
+ _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
33
+ _import_structure["transformer_2d"] = ["Transformer2DModel"]
34
+ _import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
35
+ _import_structure["unet_1d"] = ["UNet1DModel"]
36
+ _import_structure["unet_2d"] = ["UNet2DModel"]
37
+ _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
38
+ _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
39
+ _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
40
+ _import_structure["vq_model"] = ["VQModel"]
41
+
42
+ if is_flax_available():
43
+ _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
44
+ _import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
45
+ _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
46
+
47
+
48
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
49
+ if is_torch_available():
50
+ from .adapter import MultiAdapter, T2IAdapter
51
+ from .autoencoder_asym_kl import AsymmetricAutoencoderKL
52
+ from .autoencoder_kl import AutoencoderKL
53
+ from .autoencoder_tiny import AutoencoderTiny
54
+ from .consistency_decoder_vae import ConsistencyDecoderVAE
55
+ from .controlnet import ControlNetModel
56
+ from .dual_transformer_2d import DualTransformer2DModel
57
+ from .modeling_utils import ModelMixin
58
+ from .prior_transformer import PriorTransformer
59
+ from .t5_film_transformer import T5FilmDecoder
60
+ from .transformer_2d import Transformer2DModel
61
+ from .transformer_temporal import TransformerTemporalModel
62
+ from .unet_1d import UNet1DModel
63
+ from .unet_2d import UNet2DModel
64
+ from .unet_2d_condition import UNet2DConditionModel
65
+ from .unet_3d_condition import UNet3DConditionModel
66
+ from .unet_motion_model import MotionAdapter, UNetMotionModel
67
+ from .vq_model import VQModel
68
+
69
+ if is_flax_available():
70
+ from .controlnet_flax import FlaxControlNetModel
71
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
72
+ from .vae_flax import FlaxAutoencoderKL
73
+
74
+ else:
75
+ import sys
76
+
77
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diffusers/models/activations.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import USE_PEFT_BACKEND
21
+ from .lora import LoRACompatibleLinear
22
+
23
+
24
+ ACTIVATION_FUNCTIONS = {
25
+ "swish": nn.SiLU(),
26
+ "silu": nn.SiLU(),
27
+ "mish": nn.Mish(),
28
+ "gelu": nn.GELU(),
29
+ "relu": nn.ReLU(),
30
+ }
31
+
32
+
33
+ def get_activation(act_fn: str) -> nn.Module:
34
+ """Helper function to get activation function from string.
35
+
36
+ Args:
37
+ act_fn (str): Name of activation function.
38
+
39
+ Returns:
40
+ nn.Module: Activation function.
41
+ """
42
+
43
+ act_fn = act_fn.lower()
44
+ if act_fn in ACTIVATION_FUNCTIONS:
45
+ return ACTIVATION_FUNCTIONS[act_fn]
46
+ else:
47
+ raise ValueError(f"Unsupported activation function: {act_fn}")
48
+
49
+
50
+ class GELU(nn.Module):
51
+ r"""
52
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
53
+
54
+ Parameters:
55
+ dim_in (`int`): The number of channels in the input.
56
+ dim_out (`int`): The number of channels in the output.
57
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
58
+ """
59
+
60
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
61
+ super().__init__()
62
+ self.proj = nn.Linear(dim_in, dim_out)
63
+ self.approximate = approximate
64
+
65
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
66
+ if gate.device.type != "mps":
67
+ return F.gelu(gate, approximate=self.approximate)
68
+ # mps: gelu is not implemented for float16
69
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
70
+
71
+ def forward(self, hidden_states):
72
+ hidden_states = self.proj(hidden_states)
73
+ hidden_states = self.gelu(hidden_states)
74
+ return hidden_states
75
+
76
+
77
+ class GEGLU(nn.Module):
78
+ r"""
79
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
80
+
81
+ Parameters:
82
+ dim_in (`int`): The number of channels in the input.
83
+ dim_out (`int`): The number of channels in the output.
84
+ """
85
+
86
+ def __init__(self, dim_in: int, dim_out: int):
87
+ super().__init__()
88
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
89
+
90
+ self.proj = linear_cls(dim_in, dim_out * 2)
91
+
92
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
93
+ if gate.device.type != "mps":
94
+ return F.gelu(gate)
95
+ # mps: gelu is not implemented for float16
96
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
97
+
98
+ def forward(self, hidden_states, scale: float = 1.0):
99
+ args = () if USE_PEFT_BACKEND else (scale,)
100
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
101
+ return hidden_states * self.gelu(gate)
102
+
103
+
104
+ class ApproximateGELU(nn.Module):
105
+ r"""
106
+ The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
107
+ [paper](https://arxiv.org/abs/1606.08415).
108
+
109
+ Parameters:
110
+ dim_in (`int`): The number of channels in the input.
111
+ dim_out (`int`): The number of channels in the output.
112
+ """
113
+
114
+ def __init__(self, dim_in: int, dim_out: int):
115
+ super().__init__()
116
+ self.proj = nn.Linear(dim_in, dim_out)
117
+
118
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
119
+ x = self.proj(x)
120
+ return x * torch.sigmoid(1.702 * x)
diffusers/models/adapter.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Callable, List, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import logging
22
+ from .modeling_utils import ModelMixin
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class MultiAdapter(ModelMixin):
29
+ r"""
30
+ MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
31
+ user-assigned weighting.
32
+
33
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
34
+ implements for all the model (such as downloading or saving, etc.)
35
+
36
+ Parameters:
37
+ adapters (`List[T2IAdapter]`, *optional*, defaults to None):
38
+ A list of `T2IAdapter` model instances.
39
+ """
40
+
41
+ def __init__(self, adapters: List["T2IAdapter"]):
42
+ super(MultiAdapter, self).__init__()
43
+
44
+ self.num_adapter = len(adapters)
45
+ self.adapters = nn.ModuleList(adapters)
46
+
47
+ if len(adapters) == 0:
48
+ raise ValueError("Expecting at least one adapter")
49
+
50
+ if len(adapters) == 1:
51
+ raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
52
+
53
+ # The outputs from each adapter are added together with a weight.
54
+ # This means that the change in dimensions from downsampling must
55
+ # be the same for all adapters. Inductively, it also means the
56
+ # downscale_factor and total_downscale_factor must be the same for all
57
+ # adapters.
58
+ first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
59
+ first_adapter_downscale_factor = adapters[0].downscale_factor
60
+ for idx in range(1, len(adapters)):
61
+ if (
62
+ adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
63
+ or adapters[idx].downscale_factor != first_adapter_downscale_factor
64
+ ):
65
+ raise ValueError(
66
+ f"Expecting all adapters to have the same downscaling behavior, but got:\n"
67
+ f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
68
+ f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
69
+ f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
70
+ f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
71
+ )
72
+
73
+ self.total_downscale_factor = first_adapter_total_downscale_factor
74
+ self.downscale_factor = first_adapter_downscale_factor
75
+
76
+ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
77
+ r"""
78
+ Args:
79
+ xs (`torch.Tensor`):
80
+ (batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
81
+ `channel` should equal to `num_adapter` * "number of channel of image".
82
+ adapter_weights (`List[float]`, *optional*, defaults to None):
83
+ List of floats representing the weight which will be multiply to each adapter's output before adding
84
+ them together.
85
+ """
86
+ if adapter_weights is None:
87
+ adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
88
+ else:
89
+ adapter_weights = torch.tensor(adapter_weights)
90
+
91
+ accume_state = None
92
+ for x, w, adapter in zip(xs, adapter_weights, self.adapters):
93
+ features = adapter(x)
94
+ if accume_state is None:
95
+ accume_state = features
96
+ for i in range(len(accume_state)):
97
+ accume_state[i] = w * accume_state[i]
98
+ else:
99
+ for i in range(len(features)):
100
+ accume_state[i] += w * features[i]
101
+ return accume_state
102
+
103
+ def save_pretrained(
104
+ self,
105
+ save_directory: Union[str, os.PathLike],
106
+ is_main_process: bool = True,
107
+ save_function: Callable = None,
108
+ safe_serialization: bool = True,
109
+ variant: Optional[str] = None,
110
+ ):
111
+ """
112
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
113
+ `[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
114
+
115
+ Arguments:
116
+ save_directory (`str` or `os.PathLike`):
117
+ Directory to which to save. Will be created if it doesn't exist.
118
+ is_main_process (`bool`, *optional*, defaults to `True`):
119
+ Whether the process calling this is the main process or not. Useful when in distributed training like
120
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
121
+ the main process to avoid race conditions.
122
+ save_function (`Callable`):
123
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
124
+ need to replace `torch.save` by another method. Can be configured with the environment variable
125
+ `DIFFUSERS_SAVE_MODE`.
126
+ safe_serialization (`bool`, *optional*, defaults to `True`):
127
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
128
+ variant (`str`, *optional*):
129
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
130
+ """
131
+ idx = 0
132
+ model_path_to_save = save_directory
133
+ for adapter in self.adapters:
134
+ adapter.save_pretrained(
135
+ model_path_to_save,
136
+ is_main_process=is_main_process,
137
+ save_function=save_function,
138
+ safe_serialization=safe_serialization,
139
+ variant=variant,
140
+ )
141
+
142
+ idx += 1
143
+ model_path_to_save = model_path_to_save + f"_{idx}"
144
+
145
+ @classmethod
146
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
147
+ r"""
148
+ Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
149
+
150
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
151
+ the model, you should first set it back in training mode with `model.train()`.
152
+
153
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
154
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
155
+ task.
156
+
157
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
158
+ weights are discarded.
159
+
160
+ Parameters:
161
+ pretrained_model_path (`os.PathLike`):
162
+ A path to a *directory* containing model weights saved using
163
+ [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
164
+ torch_dtype (`str` or `torch.dtype`, *optional*):
165
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
166
+ will be automatically derived from the model's weights.
167
+ output_loading_info(`bool`, *optional*, defaults to `False`):
168
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
169
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
170
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
171
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
172
+ same device.
173
+
174
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
175
+ more information about each option see [designing a device
176
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
177
+ max_memory (`Dict`, *optional*):
178
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
179
+ GPU and the available CPU RAM if unset.
180
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
181
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
182
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
183
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
184
+ setting this argument to `True` will raise an error.
185
+ variant (`str`, *optional*):
186
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
187
+ ignored when using `from_flax`.
188
+ use_safetensors (`bool`, *optional*, defaults to `None`):
189
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
190
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
191
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
192
+ """
193
+ idx = 0
194
+ adapters = []
195
+
196
+ # load adapter and append to list until no adapter directory exists anymore
197
+ # first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
198
+ # second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
199
+ model_path_to_load = pretrained_model_path
200
+ while os.path.isdir(model_path_to_load):
201
+ adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
202
+ adapters.append(adapter)
203
+
204
+ idx += 1
205
+ model_path_to_load = pretrained_model_path + f"_{idx}"
206
+
207
+ logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
208
+
209
+ if len(adapters) == 0:
210
+ raise ValueError(
211
+ f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
212
+ )
213
+
214
+ return cls(adapters)
215
+
216
+
217
+ class T2IAdapter(ModelMixin, ConfigMixin):
218
+ r"""
219
+ A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
220
+ generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
221
+ architecture follows the original implementation of
222
+ [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
223
+ and
224
+ [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
225
+
226
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
227
+ implements for all the model (such as downloading or saving, etc.)
228
+
229
+ Parameters:
230
+ in_channels (`int`, *optional*, defaults to 3):
231
+ Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
232
+ image as *control image*.
233
+ channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
234
+ The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
235
+ also determine the number of downsample blocks in the Adapter.
236
+ num_res_blocks (`int`, *optional*, defaults to 2):
237
+ Number of ResNet blocks in each downsample block.
238
+ downscale_factor (`int`, *optional*, defaults to 8):
239
+ A factor that determines the total downscale factor of the Adapter.
240
+ adapter_type (`str`, *optional*, defaults to `full_adapter`):
241
+ The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
242
+ """
243
+
244
+ @register_to_config
245
+ def __init__(
246
+ self,
247
+ in_channels: int = 3,
248
+ channels: List[int] = [320, 640, 1280, 1280],
249
+ num_res_blocks: int = 2,
250
+ downscale_factor: int = 8,
251
+ adapter_type: str = "full_adapter",
252
+ ):
253
+ super().__init__()
254
+
255
+ if adapter_type == "full_adapter":
256
+ self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
257
+ elif adapter_type == "full_adapter_xl":
258
+ self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
259
+ elif adapter_type == "light_adapter":
260
+ self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
261
+ else:
262
+ raise ValueError(
263
+ f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
264
+ "'full_adapter_xl' or 'light_adapter'."
265
+ )
266
+
267
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
268
+ r"""
269
+ This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
270
+ each representing information extracted at a different scale from the input. The length of the list is
271
+ determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
272
+ `num_res_blocks` parameters during initialization.
273
+ """
274
+ return self.adapter(x)
275
+
276
+ @property
277
+ def total_downscale_factor(self):
278
+ return self.adapter.total_downscale_factor
279
+
280
+ @property
281
+ def downscale_factor(self):
282
+ """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
283
+ not evenly divisible by the downscale_factor then an exception will be raised.
284
+ """
285
+ return self.adapter.unshuffle.downscale_factor
286
+
287
+
288
+ # full adapter
289
+
290
+
291
+ class FullAdapter(nn.Module):
292
+ r"""
293
+ See [`T2IAdapter`] for more information.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ in_channels: int = 3,
299
+ channels: List[int] = [320, 640, 1280, 1280],
300
+ num_res_blocks: int = 2,
301
+ downscale_factor: int = 8,
302
+ ):
303
+ super().__init__()
304
+
305
+ in_channels = in_channels * downscale_factor**2
306
+
307
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
308
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
309
+
310
+ self.body = nn.ModuleList(
311
+ [
312
+ AdapterBlock(channels[0], channels[0], num_res_blocks),
313
+ *[
314
+ AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
315
+ for i in range(1, len(channels))
316
+ ],
317
+ ]
318
+ )
319
+
320
+ self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
321
+
322
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
323
+ r"""
324
+ This method processes the input tensor `x` through the FullAdapter model and performs operations including
325
+ pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
326
+ capturing information at a different stage of processing within the FullAdapter model. The number of feature
327
+ tensors in the list is determined by the number of downsample blocks specified during initialization.
328
+ """
329
+ x = self.unshuffle(x)
330
+ x = self.conv_in(x)
331
+
332
+ features = []
333
+
334
+ for block in self.body:
335
+ x = block(x)
336
+ features.append(x)
337
+
338
+ return features
339
+
340
+
341
+ class FullAdapterXL(nn.Module):
342
+ r"""
343
+ See [`T2IAdapter`] for more information.
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ in_channels: int = 3,
349
+ channels: List[int] = [320, 640, 1280, 1280],
350
+ num_res_blocks: int = 2,
351
+ downscale_factor: int = 16,
352
+ ):
353
+ super().__init__()
354
+
355
+ in_channels = in_channels * downscale_factor**2
356
+
357
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
358
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
359
+
360
+ self.body = []
361
+ # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
362
+ for i in range(len(channels)):
363
+ if i == 1:
364
+ self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
365
+ elif i == 2:
366
+ self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
367
+ else:
368
+ self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
369
+
370
+ self.body = nn.ModuleList(self.body)
371
+ # XL has only one downsampling AdapterBlock.
372
+ self.total_downscale_factor = downscale_factor * 2
373
+
374
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
375
+ r"""
376
+ This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
377
+ including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
378
+ """
379
+ x = self.unshuffle(x)
380
+ x = self.conv_in(x)
381
+
382
+ features = []
383
+
384
+ for block in self.body:
385
+ x = block(x)
386
+ features.append(x)
387
+
388
+ return features
389
+
390
+
391
+ class AdapterBlock(nn.Module):
392
+ r"""
393
+ An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
394
+ `FullAdapterXL` models.
395
+
396
+ Parameters:
397
+ in_channels (`int`):
398
+ Number of channels of AdapterBlock's input.
399
+ out_channels (`int`):
400
+ Number of channels of AdapterBlock's output.
401
+ num_res_blocks (`int`):
402
+ Number of ResNet blocks in the AdapterBlock.
403
+ down (`bool`, *optional*, defaults to `False`):
404
+ Whether to perform downsampling on AdapterBlock's input.
405
+ """
406
+
407
+ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
408
+ super().__init__()
409
+
410
+ self.downsample = None
411
+ if down:
412
+ self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
413
+
414
+ self.in_conv = None
415
+ if in_channels != out_channels:
416
+ self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
417
+
418
+ self.resnets = nn.Sequential(
419
+ *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ r"""
424
+ This method takes tensor x as input and performs operations downsampling and convolutional layers if the
425
+ self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
426
+ residual blocks to the input tensor.
427
+ """
428
+ if self.downsample is not None:
429
+ x = self.downsample(x)
430
+
431
+ if self.in_conv is not None:
432
+ x = self.in_conv(x)
433
+
434
+ x = self.resnets(x)
435
+
436
+ return x
437
+
438
+
439
+ class AdapterResnetBlock(nn.Module):
440
+ r"""
441
+ An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
442
+
443
+ Parameters:
444
+ channels (`int`):
445
+ Number of channels of AdapterResnetBlock's input and output.
446
+ """
447
+
448
+ def __init__(self, channels: int):
449
+ super().__init__()
450
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
451
+ self.act = nn.ReLU()
452
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
453
+
454
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
455
+ r"""
456
+ This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
457
+ layer on the input tensor. It returns addition with the input tensor.
458
+ """
459
+
460
+ h = self.act(self.block1(x))
461
+ h = self.block2(h)
462
+
463
+ return h + x
464
+
465
+
466
+ # light adapter
467
+
468
+
469
+ class LightAdapter(nn.Module):
470
+ r"""
471
+ See [`T2IAdapter`] for more information.
472
+ """
473
+
474
+ def __init__(
475
+ self,
476
+ in_channels: int = 3,
477
+ channels: List[int] = [320, 640, 1280],
478
+ num_res_blocks: int = 4,
479
+ downscale_factor: int = 8,
480
+ ):
481
+ super().__init__()
482
+
483
+ in_channels = in_channels * downscale_factor**2
484
+
485
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
486
+
487
+ self.body = nn.ModuleList(
488
+ [
489
+ LightAdapterBlock(in_channels, channels[0], num_res_blocks),
490
+ *[
491
+ LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
492
+ for i in range(len(channels) - 1)
493
+ ],
494
+ LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
495
+ ]
496
+ )
497
+
498
+ self.total_downscale_factor = downscale_factor * (2 ** len(channels))
499
+
500
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
501
+ r"""
502
+ This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
503
+ feature tensor corresponds to a different level of processing within the LightAdapter.
504
+ """
505
+ x = self.unshuffle(x)
506
+
507
+ features = []
508
+
509
+ for block in self.body:
510
+ x = block(x)
511
+ features.append(x)
512
+
513
+ return features
514
+
515
+
516
+ class LightAdapterBlock(nn.Module):
517
+ r"""
518
+ A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
519
+ `LightAdapter` model.
520
+
521
+ Parameters:
522
+ in_channels (`int`):
523
+ Number of channels of LightAdapterBlock's input.
524
+ out_channels (`int`):
525
+ Number of channels of LightAdapterBlock's output.
526
+ num_res_blocks (`int`):
527
+ Number of LightAdapterResnetBlocks in the LightAdapterBlock.
528
+ down (`bool`, *optional*, defaults to `False`):
529
+ Whether to perform downsampling on LightAdapterBlock's input.
530
+ """
531
+
532
+ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
533
+ super().__init__()
534
+ mid_channels = out_channels // 4
535
+
536
+ self.downsample = None
537
+ if down:
538
+ self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
539
+
540
+ self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
541
+ self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
542
+ self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
543
+
544
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
545
+ r"""
546
+ This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
547
+ layer, a sequence of residual blocks, and out convolutional layer.
548
+ """
549
+ if self.downsample is not None:
550
+ x = self.downsample(x)
551
+
552
+ x = self.in_conv(x)
553
+ x = self.resnets(x)
554
+ x = self.out_conv(x)
555
+
556
+ return x
557
+
558
+
559
+ class LightAdapterResnetBlock(nn.Module):
560
+ """
561
+ A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
562
+ architecture than `AdapterResnetBlock`.
563
+
564
+ Parameters:
565
+ channels (`int`):
566
+ Number of channels of LightAdapterResnetBlock's input and output.
567
+ """
568
+
569
+ def __init__(self, channels: int):
570
+ super().__init__()
571
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
572
+ self.act = nn.ReLU()
573
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
574
+
575
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
576
+ r"""
577
+ This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
578
+ another convolutional layer and adds it to input tensor.
579
+ """
580
+
581
+ h = self.act(self.block1(x))
582
+ h = self.block2(h)
583
+
584
+ return h + x
diffusers/models/attention.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from ..utils import USE_PEFT_BACKEND
20
+ from ..utils.torch_utils import maybe_allow_in_graph
21
+ from .activations import GEGLU, GELU, ApproximateGELU
22
+ from .attention_processor import Attention
23
+ from .embeddings import SinusoidalPositionalEmbedding
24
+ from .lora import LoRACompatibleLinear
25
+ from .normalization import AdaLayerNorm, AdaLayerNormZero
26
+
27
+
28
+ @maybe_allow_in_graph
29
+ class GatedSelfAttentionDense(nn.Module):
30
+ r"""
31
+ A gated self-attention dense layer that combines visual features and object features.
32
+
33
+ Parameters:
34
+ query_dim (`int`): The number of channels in the query.
35
+ context_dim (`int`): The number of channels in the context.
36
+ n_heads (`int`): The number of heads to use for attention.
37
+ d_head (`int`): The number of channels in each head.
38
+ """
39
+
40
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
41
+ super().__init__()
42
+
43
+ # we need a linear projection since we need cat visual feature and obj feature
44
+ self.linear = nn.Linear(context_dim, query_dim)
45
+
46
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
47
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
48
+
49
+ self.norm1 = nn.LayerNorm(query_dim)
50
+ self.norm2 = nn.LayerNorm(query_dim)
51
+
52
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
53
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
54
+
55
+ self.enabled = True
56
+
57
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
58
+ if not self.enabled:
59
+ return x
60
+
61
+ n_visual = x.shape[1]
62
+ objs = self.linear(objs)
63
+
64
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
65
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
66
+
67
+ return x
68
+
69
+
70
+ @maybe_allow_in_graph
71
+ class BasicTransformerBlock(nn.Module):
72
+ r"""
73
+ A basic Transformer block.
74
+
75
+ Parameters:
76
+ dim (`int`): The number of channels in the input and output.
77
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
78
+ attention_head_dim (`int`): The number of channels in each head.
79
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
80
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
81
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
82
+ num_embeds_ada_norm (:
83
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
84
+ attention_bias (:
85
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
86
+ only_cross_attention (`bool`, *optional*):
87
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
88
+ double_self_attention (`bool`, *optional*):
89
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
90
+ upcast_attention (`bool`, *optional*):
91
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
92
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
93
+ Whether to use learnable elementwise affine parameters for normalization.
94
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
95
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
96
+ final_dropout (`bool` *optional*, defaults to False):
97
+ Whether to apply a final dropout after the last feed-forward layer.
98
+ attention_type (`str`, *optional*, defaults to `"default"`):
99
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
100
+ positional_embeddings (`str`, *optional*, defaults to `None`):
101
+ The type of positional embeddings to apply to.
102
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
103
+ The maximum number of positional embeddings to apply.
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ dim: int,
109
+ num_attention_heads: int,
110
+ attention_head_dim: int,
111
+ dropout=0.0,
112
+ cross_attention_dim: Optional[int] = None,
113
+ activation_fn: str = "geglu",
114
+ num_embeds_ada_norm: Optional[int] = None,
115
+ attention_bias: bool = False,
116
+ only_cross_attention: bool = False,
117
+ double_self_attention: bool = False,
118
+ upcast_attention: bool = False,
119
+ norm_elementwise_affine: bool = True,
120
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
121
+ norm_eps: float = 1e-5,
122
+ final_dropout: bool = False,
123
+ attention_type: str = "default",
124
+ positional_embeddings: Optional[str] = None,
125
+ num_positional_embeddings: Optional[int] = None,
126
+ ):
127
+ super().__init__()
128
+ self.only_cross_attention = only_cross_attention
129
+
130
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
131
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
132
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
133
+ self.use_layer_norm = norm_type == "layer_norm"
134
+
135
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
136
+ raise ValueError(
137
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
138
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
139
+ )
140
+
141
+ if positional_embeddings and (num_positional_embeddings is None):
142
+ raise ValueError(
143
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
144
+ )
145
+
146
+ if positional_embeddings == "sinusoidal":
147
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
148
+ else:
149
+ self.pos_embed = None
150
+
151
+ # Define 3 blocks. Each block has its own normalization layer.
152
+ # 1. Self-Attn
153
+ if self.use_ada_layer_norm:
154
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
155
+ elif self.use_ada_layer_norm_zero:
156
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
157
+ else:
158
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
159
+
160
+ self.attn1 = Attention(
161
+ query_dim=dim,
162
+ heads=num_attention_heads,
163
+ dim_head=attention_head_dim,
164
+ dropout=dropout,
165
+ bias=attention_bias,
166
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
167
+ upcast_attention=upcast_attention,
168
+ )
169
+
170
+ # 2. Cross-Attn
171
+ if cross_attention_dim is not None or double_self_attention:
172
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
173
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
174
+ # the second cross attention block.
175
+ self.norm2 = (
176
+ AdaLayerNorm(dim, num_embeds_ada_norm)
177
+ if self.use_ada_layer_norm
178
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
179
+ )
180
+ self.attn2 = Attention(
181
+ query_dim=dim,
182
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
183
+ heads=num_attention_heads,
184
+ dim_head=attention_head_dim,
185
+ dropout=dropout,
186
+ bias=attention_bias,
187
+ upcast_attention=upcast_attention,
188
+ ) # is self-attn if encoder_hidden_states is none
189
+ else:
190
+ self.norm2 = None
191
+ self.attn2 = None
192
+
193
+ # 3. Feed-forward
194
+ if not self.use_ada_layer_norm_single:
195
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
196
+
197
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
198
+
199
+ # 4. Fuser
200
+ if attention_type == "gated" or attention_type == "gated-text-image":
201
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
202
+
203
+ # 5. Scale-shift for PixArt-Alpha.
204
+ if self.use_ada_layer_norm_single:
205
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
206
+
207
+ # let chunk size default to None
208
+ self._chunk_size = None
209
+ self._chunk_dim = 0
210
+
211
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
212
+ # Sets chunk feed-forward
213
+ self._chunk_size = chunk_size
214
+ self._chunk_dim = dim
215
+
216
+ def forward(
217
+ self,
218
+ hidden_states: torch.FloatTensor,
219
+ attention_mask: Optional[torch.FloatTensor] = None,
220
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
221
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
222
+ timestep: Optional[torch.LongTensor] = None,
223
+ cross_attention_kwargs: Dict[str, Any] = None,
224
+ class_labels: Optional[torch.LongTensor] = None,
225
+ ) -> torch.FloatTensor:
226
+ # Notice that normalization is always applied before the real computation in the following blocks.
227
+ # 0. Self-Attention
228
+ batch_size = hidden_states.shape[0]
229
+
230
+ if self.use_ada_layer_norm:
231
+ norm_hidden_states = self.norm1(hidden_states, timestep)
232
+ elif self.use_ada_layer_norm_zero:
233
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
234
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
235
+ )
236
+ elif self.use_layer_norm:
237
+ norm_hidden_states = self.norm1(hidden_states)
238
+ elif self.use_ada_layer_norm_single:
239
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
240
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
241
+ ).chunk(6, dim=1)
242
+ norm_hidden_states = self.norm1(hidden_states)
243
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
244
+ norm_hidden_states = norm_hidden_states.squeeze(1)
245
+ else:
246
+ raise ValueError("Incorrect norm used")
247
+
248
+ if self.pos_embed is not None:
249
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
250
+
251
+ # 1. Retrieve lora scale.
252
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
253
+
254
+ # 2. Prepare GLIGEN inputs
255
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
256
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
257
+
258
+ attn_output = self.attn1(
259
+ norm_hidden_states, # 32 4096 320
260
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, # 32 77 768
261
+ attention_mask=attention_mask,
262
+ **cross_attention_kwargs,
263
+ )
264
+ if self.use_ada_layer_norm_zero:
265
+ attn_output = gate_msa.unsqueeze(1) * attn_output
266
+ elif self.use_ada_layer_norm_single:
267
+ attn_output = gate_msa * attn_output
268
+
269
+ hidden_states = attn_output + hidden_states
270
+ if hidden_states.ndim == 4:
271
+ hidden_states = hidden_states.squeeze(1)
272
+
273
+ # 2.5 GLIGEN Control
274
+ if gligen_kwargs is not None:
275
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
276
+
277
+ # 3. Cross-Attention
278
+ if self.attn2 is not None:
279
+ if self.use_ada_layer_norm:
280
+ norm_hidden_states = self.norm2(hidden_states, timestep)
281
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
282
+ norm_hidden_states = self.norm2(hidden_states)
283
+ elif self.use_ada_layer_norm_single:
284
+ # For PixArt norm2 isn't applied here:
285
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
286
+ norm_hidden_states = hidden_states
287
+ else:
288
+ raise ValueError("Incorrect norm")
289
+
290
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
291
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
292
+
293
+ attn_output = self.attn2(
294
+ norm_hidden_states,
295
+ encoder_hidden_states=encoder_hidden_states,
296
+ attention_mask=encoder_attention_mask,
297
+ **cross_attention_kwargs,
298
+ )
299
+ hidden_states = attn_output + hidden_states
300
+
301
+ # 4. Feed-forward
302
+ if not self.use_ada_layer_norm_single:
303
+ norm_hidden_states = self.norm3(hidden_states)
304
+
305
+ if self.use_ada_layer_norm_zero:
306
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
307
+
308
+ if self.use_ada_layer_norm_single:
309
+ norm_hidden_states = self.norm2(hidden_states)
310
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
311
+
312
+ if self._chunk_size is not None:
313
+ # "feed_forward_chunk_size" can be used to save memory
314
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
315
+ raise ValueError(
316
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
317
+ )
318
+
319
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
320
+ ff_output = torch.cat(
321
+ [
322
+ self.ff(hid_slice, scale=lora_scale)
323
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
324
+ ],
325
+ dim=self._chunk_dim,
326
+ )
327
+ else:
328
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
329
+
330
+ if self.use_ada_layer_norm_zero:
331
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
332
+ elif self.use_ada_layer_norm_single:
333
+ ff_output = gate_mlp * ff_output
334
+
335
+ hidden_states = ff_output + hidden_states
336
+ if hidden_states.ndim == 4:
337
+ hidden_states = hidden_states.squeeze(1)
338
+
339
+ return hidden_states
340
+
341
+
342
+ class FeedForward(nn.Module):
343
+ r"""
344
+ A feed-forward layer.
345
+
346
+ Parameters:
347
+ dim (`int`): The number of channels in the input.
348
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
349
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
350
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
351
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
352
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
353
+ """
354
+
355
+ def __init__(
356
+ self,
357
+ dim: int,
358
+ dim_out: Optional[int] = None,
359
+ mult: int = 4,
360
+ dropout: float = 0.0,
361
+ activation_fn: str = "geglu",
362
+ final_dropout: bool = False,
363
+ ):
364
+ super().__init__()
365
+ inner_dim = int(dim * mult)
366
+ dim_out = dim_out if dim_out is not None else dim
367
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
368
+
369
+ if activation_fn == "gelu":
370
+ act_fn = GELU(dim, inner_dim)
371
+ if activation_fn == "gelu-approximate":
372
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
373
+ elif activation_fn == "geglu":
374
+ act_fn = GEGLU(dim, inner_dim)
375
+ elif activation_fn == "geglu-approximate":
376
+ act_fn = ApproximateGELU(dim, inner_dim)
377
+
378
+ self.net = nn.ModuleList([])
379
+ # project in
380
+ self.net.append(act_fn)
381
+ # project dropout
382
+ self.net.append(nn.Dropout(dropout))
383
+ # project out
384
+ self.net.append(linear_cls(inner_dim, dim_out))
385
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
386
+ if final_dropout:
387
+ self.net.append(nn.Dropout(dropout))
388
+
389
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
390
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
391
+ for module in self.net:
392
+ if isinstance(module, compatible_cls):
393
+ hidden_states = module(hidden_states, scale)
394
+ else:
395
+ hidden_states = module(hidden_states)
396
+ return hidden_states
diffusers/models/attention_flax.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+
22
+
23
+ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
24
+ """Multi-head dot product attention with a limited number of queries."""
25
+ num_kv, num_heads, k_features = key.shape[-3:]
26
+ v_features = value.shape[-1]
27
+ key_chunk_size = min(key_chunk_size, num_kv)
28
+ query = query / jnp.sqrt(k_features)
29
+
30
+ @functools.partial(jax.checkpoint, prevent_cse=False)
31
+ def summarize_chunk(query, key, value):
32
+ attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
33
+
34
+ max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
35
+ max_score = jax.lax.stop_gradient(max_score)
36
+ exp_weights = jnp.exp(attn_weights - max_score)
37
+
38
+ exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
39
+ max_score = jnp.einsum("...qhk->...qh", max_score)
40
+
41
+ return (exp_values, exp_weights.sum(axis=-1), max_score)
42
+
43
+ def chunk_scanner(chunk_idx):
44
+ # julienne key array
45
+ key_chunk = jax.lax.dynamic_slice(
46
+ operand=key,
47
+ start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
48
+ slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
49
+ )
50
+
51
+ # julienne value array
52
+ value_chunk = jax.lax.dynamic_slice(
53
+ operand=value,
54
+ start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
55
+ slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
56
+ )
57
+
58
+ return summarize_chunk(query, key_chunk, value_chunk)
59
+
60
+ chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
61
+
62
+ global_max = jnp.max(chunk_max, axis=0, keepdims=True)
63
+ max_diffs = jnp.exp(chunk_max - global_max)
64
+
65
+ chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
66
+ chunk_weights *= max_diffs
67
+
68
+ all_values = chunk_values.sum(axis=0)
69
+ all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
70
+
71
+ return all_values / all_weights
72
+
73
+
74
+ def jax_memory_efficient_attention(
75
+ query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
76
+ ):
77
+ r"""
78
+ Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
79
+ https://github.com/AminRezaei0x443/memory-efficient-attention
80
+
81
+ Args:
82
+ query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
83
+ key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
84
+ value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
85
+ precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
86
+ numerical precision for computation
87
+ query_chunk_size (`int`, *optional*, defaults to 1024):
88
+ chunk size to divide query array value must divide query_length equally without remainder
89
+ key_chunk_size (`int`, *optional*, defaults to 4096):
90
+ chunk size to divide key and value array value must divide key_value_length equally without remainder
91
+
92
+ Returns:
93
+ (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
94
+ """
95
+ num_q, num_heads, q_features = query.shape[-3:]
96
+
97
+ def chunk_scanner(chunk_idx, _):
98
+ # julienne query array
99
+ query_chunk = jax.lax.dynamic_slice(
100
+ operand=query,
101
+ start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
102
+ slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
103
+ )
104
+
105
+ return (
106
+ chunk_idx + query_chunk_size, # unused ignore it
107
+ _query_chunk_attention(
108
+ query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
109
+ ),
110
+ )
111
+
112
+ _, res = jax.lax.scan(
113
+ f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter
114
+ )
115
+
116
+ return jnp.concatenate(res, axis=-3) # fuse the chunked result back
117
+
118
+
119
+ class FlaxAttention(nn.Module):
120
+ r"""
121
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
122
+
123
+ Parameters:
124
+ query_dim (:obj:`int`):
125
+ Input hidden states dimension
126
+ heads (:obj:`int`, *optional*, defaults to 8):
127
+ Number of heads
128
+ dim_head (:obj:`int`, *optional*, defaults to 64):
129
+ Hidden states dimension inside each head
130
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
131
+ Dropout rate
132
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
133
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
134
+ split_head_dim (`bool`, *optional*, defaults to `False`):
135
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
136
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
137
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
138
+ Parameters `dtype`
139
+
140
+ """
141
+ query_dim: int
142
+ heads: int = 8
143
+ dim_head: int = 64
144
+ dropout: float = 0.0
145
+ use_memory_efficient_attention: bool = False
146
+ split_head_dim: bool = False
147
+ dtype: jnp.dtype = jnp.float32
148
+
149
+ def setup(self):
150
+ inner_dim = self.dim_head * self.heads
151
+ self.scale = self.dim_head**-0.5
152
+
153
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
154
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
155
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
156
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
157
+
158
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
159
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
160
+
161
+ def reshape_heads_to_batch_dim(self, tensor):
162
+ batch_size, seq_len, dim = tensor.shape
163
+ head_size = self.heads
164
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
165
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
166
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
167
+ return tensor
168
+
169
+ def reshape_batch_dim_to_heads(self, tensor):
170
+ batch_size, seq_len, dim = tensor.shape
171
+ head_size = self.heads
172
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
173
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
174
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
175
+ return tensor
176
+
177
+ def __call__(self, hidden_states, context=None, deterministic=True):
178
+ context = hidden_states if context is None else context
179
+
180
+ query_proj = self.query(hidden_states)
181
+ key_proj = self.key(context)
182
+ value_proj = self.value(context)
183
+
184
+ if self.split_head_dim:
185
+ b = hidden_states.shape[0]
186
+ query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
187
+ key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
188
+ value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
189
+ else:
190
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
191
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
192
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
193
+
194
+ if self.use_memory_efficient_attention:
195
+ query_states = query_states.transpose(1, 0, 2)
196
+ key_states = key_states.transpose(1, 0, 2)
197
+ value_states = value_states.transpose(1, 0, 2)
198
+
199
+ # this if statement create a chunk size for each layer of the unet
200
+ # the chunk size is equal to the query_length dimension of the deepest layer of the unet
201
+
202
+ flatten_latent_dim = query_states.shape[-3]
203
+ if flatten_latent_dim % 64 == 0:
204
+ query_chunk_size = int(flatten_latent_dim / 64)
205
+ elif flatten_latent_dim % 16 == 0:
206
+ query_chunk_size = int(flatten_latent_dim / 16)
207
+ elif flatten_latent_dim % 4 == 0:
208
+ query_chunk_size = int(flatten_latent_dim / 4)
209
+ else:
210
+ query_chunk_size = int(flatten_latent_dim)
211
+
212
+ hidden_states = jax_memory_efficient_attention(
213
+ query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
214
+ )
215
+
216
+ hidden_states = hidden_states.transpose(1, 0, 2)
217
+ else:
218
+ # compute attentions
219
+ if self.split_head_dim:
220
+ attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
221
+ else:
222
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
223
+
224
+ attention_scores = attention_scores * self.scale
225
+ attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
226
+
227
+ # attend to values
228
+ if self.split_head_dim:
229
+ hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
230
+ b = hidden_states.shape[0]
231
+ hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
232
+ else:
233
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
234
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
235
+
236
+ hidden_states = self.proj_attn(hidden_states)
237
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
238
+
239
+
240
+ class FlaxBasicTransformerBlock(nn.Module):
241
+ r"""
242
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
243
+ https://arxiv.org/abs/1706.03762
244
+
245
+
246
+ Parameters:
247
+ dim (:obj:`int`):
248
+ Inner hidden states dimension
249
+ n_heads (:obj:`int`):
250
+ Number of heads
251
+ d_head (:obj:`int`):
252
+ Hidden states dimension inside each head
253
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
254
+ Dropout rate
255
+ only_cross_attention (`bool`, defaults to `False`):
256
+ Whether to only apply cross attention.
257
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
258
+ Parameters `dtype`
259
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
260
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
261
+ split_head_dim (`bool`, *optional*, defaults to `False`):
262
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
263
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
264
+ """
265
+ dim: int
266
+ n_heads: int
267
+ d_head: int
268
+ dropout: float = 0.0
269
+ only_cross_attention: bool = False
270
+ dtype: jnp.dtype = jnp.float32
271
+ use_memory_efficient_attention: bool = False
272
+ split_head_dim: bool = False
273
+
274
+ def setup(self):
275
+ # self attention (or cross_attention if only_cross_attention is True)
276
+ self.attn1 = FlaxAttention(
277
+ self.dim,
278
+ self.n_heads,
279
+ self.d_head,
280
+ self.dropout,
281
+ self.use_memory_efficient_attention,
282
+ self.split_head_dim,
283
+ dtype=self.dtype,
284
+ )
285
+ # cross attention
286
+ self.attn2 = FlaxAttention(
287
+ self.dim,
288
+ self.n_heads,
289
+ self.d_head,
290
+ self.dropout,
291
+ self.use_memory_efficient_attention,
292
+ self.split_head_dim,
293
+ dtype=self.dtype,
294
+ )
295
+ self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
296
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
297
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
298
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
299
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
300
+
301
+ def __call__(self, hidden_states, context, deterministic=True):
302
+ # self attention
303
+ residual = hidden_states
304
+ if self.only_cross_attention:
305
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
306
+ else:
307
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
308
+ hidden_states = hidden_states + residual
309
+
310
+ # cross attention
311
+ residual = hidden_states
312
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
313
+ hidden_states = hidden_states + residual
314
+
315
+ # feed forward
316
+ residual = hidden_states
317
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
318
+ hidden_states = hidden_states + residual
319
+
320
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
321
+
322
+
323
+ class FlaxTransformer2DModel(nn.Module):
324
+ r"""
325
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
326
+ https://arxiv.org/pdf/1506.02025.pdf
327
+
328
+
329
+ Parameters:
330
+ in_channels (:obj:`int`):
331
+ Input number of channels
332
+ n_heads (:obj:`int`):
333
+ Number of heads
334
+ d_head (:obj:`int`):
335
+ Hidden states dimension inside each head
336
+ depth (:obj:`int`, *optional*, defaults to 1):
337
+ Number of transformers block
338
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
339
+ Dropout rate
340
+ use_linear_projection (`bool`, defaults to `False`): tbd
341
+ only_cross_attention (`bool`, defaults to `False`): tbd
342
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
343
+ Parameters `dtype`
344
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
345
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
346
+ split_head_dim (`bool`, *optional*, defaults to `False`):
347
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
348
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
349
+ """
350
+ in_channels: int
351
+ n_heads: int
352
+ d_head: int
353
+ depth: int = 1
354
+ dropout: float = 0.0
355
+ use_linear_projection: bool = False
356
+ only_cross_attention: bool = False
357
+ dtype: jnp.dtype = jnp.float32
358
+ use_memory_efficient_attention: bool = False
359
+ split_head_dim: bool = False
360
+
361
+ def setup(self):
362
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
363
+
364
+ inner_dim = self.n_heads * self.d_head
365
+ if self.use_linear_projection:
366
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
367
+ else:
368
+ self.proj_in = nn.Conv(
369
+ inner_dim,
370
+ kernel_size=(1, 1),
371
+ strides=(1, 1),
372
+ padding="VALID",
373
+ dtype=self.dtype,
374
+ )
375
+
376
+ self.transformer_blocks = [
377
+ FlaxBasicTransformerBlock(
378
+ inner_dim,
379
+ self.n_heads,
380
+ self.d_head,
381
+ dropout=self.dropout,
382
+ only_cross_attention=self.only_cross_attention,
383
+ dtype=self.dtype,
384
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
385
+ split_head_dim=self.split_head_dim,
386
+ )
387
+ for _ in range(self.depth)
388
+ ]
389
+
390
+ if self.use_linear_projection:
391
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
392
+ else:
393
+ self.proj_out = nn.Conv(
394
+ inner_dim,
395
+ kernel_size=(1, 1),
396
+ strides=(1, 1),
397
+ padding="VALID",
398
+ dtype=self.dtype,
399
+ )
400
+
401
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
402
+
403
+ def __call__(self, hidden_states, context, deterministic=True):
404
+ batch, height, width, channels = hidden_states.shape
405
+ residual = hidden_states
406
+ hidden_states = self.norm(hidden_states)
407
+ if self.use_linear_projection:
408
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
409
+ hidden_states = self.proj_in(hidden_states)
410
+ else:
411
+ hidden_states = self.proj_in(hidden_states)
412
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
413
+
414
+ for transformer_block in self.transformer_blocks:
415
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
416
+
417
+ if self.use_linear_projection:
418
+ hidden_states = self.proj_out(hidden_states)
419
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
420
+ else:
421
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
422
+ hidden_states = self.proj_out(hidden_states)
423
+
424
+ hidden_states = hidden_states + residual
425
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
426
+
427
+
428
+ class FlaxFeedForward(nn.Module):
429
+ r"""
430
+ Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
431
+ [`FeedForward`] class, with the following simplifications:
432
+ - The activation function is currently hardcoded to a gated linear unit from:
433
+ https://arxiv.org/abs/2002.05202
434
+ - `dim_out` is equal to `dim`.
435
+ - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
436
+
437
+ Parameters:
438
+ dim (:obj:`int`):
439
+ Inner hidden states dimension
440
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
441
+ Dropout rate
442
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
443
+ Parameters `dtype`
444
+ """
445
+ dim: int
446
+ dropout: float = 0.0
447
+ dtype: jnp.dtype = jnp.float32
448
+
449
+ def setup(self):
450
+ # The second linear layer needs to be called
451
+ # net_2 for now to match the index of the Sequential layer
452
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
453
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
454
+
455
+ def __call__(self, hidden_states, deterministic=True):
456
+ hidden_states = self.net_0(hidden_states, deterministic=deterministic)
457
+ hidden_states = self.net_2(hidden_states)
458
+ return hidden_states
459
+
460
+
461
+ class FlaxGEGLU(nn.Module):
462
+ r"""
463
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
464
+ https://arxiv.org/abs/2002.05202.
465
+
466
+ Parameters:
467
+ dim (:obj:`int`):
468
+ Input hidden states dimension
469
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
470
+ Dropout rate
471
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
472
+ Parameters `dtype`
473
+ """
474
+ dim: int
475
+ dropout: float = 0.0
476
+ dtype: jnp.dtype = jnp.float32
477
+
478
+ def setup(self):
479
+ inner_dim = self.dim * 4
480
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
481
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
482
+
483
+ def __call__(self, hidden_states, deterministic=True):
484
+ hidden_states = self.proj(hidden_states)
485
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
486
+ return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
diffusers/models/attention_processor.py ADDED
@@ -0,0 +1,2020 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from importlib import import_module
15
+ from typing import Callable, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..utils import USE_PEFT_BACKEND, deprecate, logging
22
+ from ..utils.import_utils import is_xformers_available
23
+ from ..utils.torch_utils import maybe_allow_in_graph
24
+ from .lora import LoRACompatibleLinear, LoRALinearLayer
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ if is_xformers_available():
31
+ import xformers
32
+ import xformers.ops
33
+ else:
34
+ xformers = None
35
+
36
+
37
+ @maybe_allow_in_graph
38
+ class Attention(nn.Module):
39
+ r"""
40
+ A cross attention layer.
41
+
42
+ Parameters:
43
+ query_dim (`int`):
44
+ The number of channels in the query.
45
+ cross_attention_dim (`int`, *optional*):
46
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
47
+ heads (`int`, *optional*, defaults to 8):
48
+ The number of heads to use for multi-head attention.
49
+ dim_head (`int`, *optional*, defaults to 64):
50
+ The number of channels in each head.
51
+ dropout (`float`, *optional*, defaults to 0.0):
52
+ The dropout probability to use.
53
+ bias (`bool`, *optional*, defaults to False):
54
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
55
+ upcast_attention (`bool`, *optional*, defaults to False):
56
+ Set to `True` to upcast the attention computation to `float32`.
57
+ upcast_softmax (`bool`, *optional*, defaults to False):
58
+ Set to `True` to upcast the softmax computation to `float32`.
59
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
60
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
61
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
62
+ The number of groups to use for the group norm in the cross attention.
63
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
64
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
65
+ norm_num_groups (`int`, *optional*, defaults to `None`):
66
+ The number of groups to use for the group norm in the attention.
67
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
68
+ The number of channels to use for the spatial normalization.
69
+ out_bias (`bool`, *optional*, defaults to `True`):
70
+ Set to `True` to use a bias in the output linear layer.
71
+ scale_qk (`bool`, *optional*, defaults to `True`):
72
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
73
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
74
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
75
+ `added_kv_proj_dim` is not `None`.
76
+ eps (`float`, *optional*, defaults to 1e-5):
77
+ An additional value added to the denominator in group normalization that is used for numerical stability.
78
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
79
+ A factor to rescale the output by dividing it with this value.
80
+ residual_connection (`bool`, *optional*, defaults to `False`):
81
+ Set to `True` to add the residual connection to the output.
82
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
83
+ Set to `True` if the attention block is loaded from a deprecated state dict.
84
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
85
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
86
+ `AttnProcessor` otherwise.
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ query_dim: int,
92
+ cross_attention_dim: Optional[int] = None,
93
+ heads: int = 8,
94
+ dim_head: int = 64,
95
+ dropout: float = 0.0,
96
+ bias: bool = False,
97
+ upcast_attention: bool = False,
98
+ upcast_softmax: bool = False,
99
+ cross_attention_norm: Optional[str] = None,
100
+ cross_attention_norm_num_groups: int = 32,
101
+ added_kv_proj_dim: Optional[int] = None,
102
+ norm_num_groups: Optional[int] = None,
103
+ spatial_norm_dim: Optional[int] = None,
104
+ out_bias: bool = True,
105
+ scale_qk: bool = True,
106
+ only_cross_attention: bool = False,
107
+ eps: float = 1e-5,
108
+ rescale_output_factor: float = 1.0,
109
+ residual_connection: bool = False,
110
+ _from_deprecated_attn_block: bool = False,
111
+ processor: Optional["AttnProcessor"] = None,
112
+ ):
113
+ super().__init__()
114
+ self.inner_dim = dim_head * heads
115
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
116
+ self.upcast_attention = upcast_attention
117
+ self.upcast_softmax = upcast_softmax
118
+ self.rescale_output_factor = rescale_output_factor
119
+ self.residual_connection = residual_connection
120
+ self.dropout = dropout
121
+
122
+ # we make use of this private variable to know whether this class is loaded
123
+ # with an deprecated state dict so that we can convert it on the fly
124
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
125
+
126
+ self.scale_qk = scale_qk
127
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
128
+
129
+ self.heads = heads
130
+ # for slice_size > 0 the attention score computation
131
+ # is split across the batch axis to save memory
132
+ # You can set slice_size with `set_attention_slice`
133
+ self.sliceable_head_dim = heads
134
+
135
+ self.added_kv_proj_dim = added_kv_proj_dim
136
+ self.only_cross_attention = only_cross_attention
137
+
138
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
139
+ raise ValueError(
140
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
141
+ )
142
+
143
+ if norm_num_groups is not None:
144
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
145
+ else:
146
+ self.group_norm = None
147
+
148
+ if spatial_norm_dim is not None:
149
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
150
+ else:
151
+ self.spatial_norm = None
152
+
153
+ if cross_attention_norm is None:
154
+ self.norm_cross = None
155
+ elif cross_attention_norm == "layer_norm":
156
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
157
+ elif cross_attention_norm == "group_norm":
158
+ if self.added_kv_proj_dim is not None:
159
+ # The given `encoder_hidden_states` are initially of shape
160
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
161
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
162
+ # before the projection, so we need to use `added_kv_proj_dim` as
163
+ # the number of channels for the group norm.
164
+ norm_cross_num_channels = added_kv_proj_dim
165
+ else:
166
+ norm_cross_num_channels = self.cross_attention_dim
167
+
168
+ self.norm_cross = nn.GroupNorm(
169
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
170
+ )
171
+ else:
172
+ raise ValueError(
173
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
174
+ )
175
+
176
+ if USE_PEFT_BACKEND:
177
+ linear_cls = nn.Linear
178
+ else:
179
+ linear_cls = LoRACompatibleLinear
180
+
181
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
182
+
183
+ if not self.only_cross_attention:
184
+ # only relevant for the `AddedKVProcessor` classes
185
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
186
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
187
+ else:
188
+ self.to_k = None
189
+ self.to_v = None
190
+
191
+ if self.added_kv_proj_dim is not None:
192
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
193
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
194
+
195
+ self.to_out = nn.ModuleList([])
196
+ self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
197
+ self.to_out.append(nn.Dropout(dropout))
198
+
199
+ # set attention processor
200
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
201
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
202
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
203
+ if processor is None:
204
+ processor = (
205
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
206
+ )
207
+ self.set_processor(processor)
208
+
209
+ def set_use_memory_efficient_attention_xformers(
210
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
211
+ ) -> None:
212
+ r"""
213
+ Set whether to use memory efficient attention from `xformers` or not.
214
+
215
+ Args:
216
+ use_memory_efficient_attention_xformers (`bool`):
217
+ Whether to use memory efficient attention from `xformers` or not.
218
+ attention_op (`Callable`, *optional*):
219
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
220
+ `xformers`.
221
+ """
222
+ is_lora = hasattr(self, "processor") and isinstance(
223
+ self.processor,
224
+ LORA_ATTENTION_PROCESSORS,
225
+ )
226
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
227
+ self.processor,
228
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
229
+ )
230
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
231
+ self.processor,
232
+ (
233
+ AttnAddedKVProcessor,
234
+ AttnAddedKVProcessor2_0,
235
+ SlicedAttnAddedKVProcessor,
236
+ XFormersAttnAddedKVProcessor,
237
+ LoRAAttnAddedKVProcessor,
238
+ ),
239
+ )
240
+
241
+ if use_memory_efficient_attention_xformers:
242
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
243
+ raise NotImplementedError(
244
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
245
+ )
246
+ if not is_xformers_available():
247
+ raise ModuleNotFoundError(
248
+ (
249
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
250
+ " xformers"
251
+ ),
252
+ name="xformers",
253
+ )
254
+ elif not torch.cuda.is_available():
255
+ raise ValueError(
256
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
257
+ " only available for GPU "
258
+ )
259
+ else:
260
+ try:
261
+ # Make sure we can run the memory efficient attention
262
+ _ = xformers.ops.memory_efficient_attention(
263
+ torch.randn((1, 2, 40), device="cuda"),
264
+ torch.randn((1, 2, 40), device="cuda"),
265
+ torch.randn((1, 2, 40), device="cuda"),
266
+ )
267
+ except Exception as e:
268
+ raise e
269
+
270
+ if is_lora:
271
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
272
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
273
+ processor = LoRAXFormersAttnProcessor(
274
+ hidden_size=self.processor.hidden_size,
275
+ cross_attention_dim=self.processor.cross_attention_dim,
276
+ rank=self.processor.rank,
277
+ attention_op=attention_op,
278
+ )
279
+ processor.load_state_dict(self.processor.state_dict())
280
+ processor.to(self.processor.to_q_lora.up.weight.device)
281
+ elif is_custom_diffusion:
282
+ processor = CustomDiffusionXFormersAttnProcessor(
283
+ train_kv=self.processor.train_kv,
284
+ train_q_out=self.processor.train_q_out,
285
+ hidden_size=self.processor.hidden_size,
286
+ cross_attention_dim=self.processor.cross_attention_dim,
287
+ attention_op=attention_op,
288
+ )
289
+ processor.load_state_dict(self.processor.state_dict())
290
+ if hasattr(self.processor, "to_k_custom_diffusion"):
291
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
292
+ elif is_added_kv_processor:
293
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
294
+ # which uses this type of cross attention ONLY because the attention mask of format
295
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
296
+ # throw warning
297
+ logger.info(
298
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
299
+ )
300
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
301
+ else:
302
+ processor = XFormersAttnProcessor(attention_op=attention_op)
303
+ else:
304
+ if is_lora:
305
+ attn_processor_class = (
306
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
307
+ )
308
+ processor = attn_processor_class(
309
+ hidden_size=self.processor.hidden_size,
310
+ cross_attention_dim=self.processor.cross_attention_dim,
311
+ rank=self.processor.rank,
312
+ )
313
+ processor.load_state_dict(self.processor.state_dict())
314
+ processor.to(self.processor.to_q_lora.up.weight.device)
315
+ elif is_custom_diffusion:
316
+ attn_processor_class = (
317
+ CustomDiffusionAttnProcessor2_0
318
+ if hasattr(F, "scaled_dot_product_attention")
319
+ else CustomDiffusionAttnProcessor
320
+ )
321
+ processor = attn_processor_class(
322
+ train_kv=self.processor.train_kv,
323
+ train_q_out=self.processor.train_q_out,
324
+ hidden_size=self.processor.hidden_size,
325
+ cross_attention_dim=self.processor.cross_attention_dim,
326
+ )
327
+ processor.load_state_dict(self.processor.state_dict())
328
+ if hasattr(self.processor, "to_k_custom_diffusion"):
329
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
330
+ else:
331
+ # set attention processor
332
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
333
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
334
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
335
+ processor = (
336
+ AttnProcessor2_0()
337
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
338
+ else AttnProcessor()
339
+ )
340
+
341
+ self.set_processor(processor)
342
+
343
+ def set_attention_slice(self, slice_size: int) -> None:
344
+ r"""
345
+ Set the slice size for attention computation.
346
+
347
+ Args:
348
+ slice_size (`int`):
349
+ The slice size for attention computation.
350
+ """
351
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
352
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
353
+
354
+ if slice_size is not None and self.added_kv_proj_dim is not None:
355
+ processor = SlicedAttnAddedKVProcessor(slice_size)
356
+ elif slice_size is not None:
357
+ processor = SlicedAttnProcessor(slice_size)
358
+ elif self.added_kv_proj_dim is not None:
359
+ processor = AttnAddedKVProcessor()
360
+ else:
361
+ # set attention processor
362
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
363
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
364
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
365
+ processor = (
366
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
367
+ )
368
+
369
+ self.set_processor(processor)
370
+
371
+ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
372
+ r"""
373
+ Set the attention processor to use.
374
+
375
+ Args:
376
+ processor (`AttnProcessor`):
377
+ The attention processor to use.
378
+ _remove_lora (`bool`, *optional*, defaults to `False`):
379
+ Set to `True` to remove LoRA layers from the model.
380
+ """
381
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
382
+ deprecate(
383
+ "set_processor to offload LoRA",
384
+ "0.26.0",
385
+ "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
386
+ )
387
+ # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
388
+ # We need to remove all LoRA layers
389
+ # Don't forget to remove ALL `_remove_lora` from the codebase
390
+ for module in self.modules():
391
+ if hasattr(module, "set_lora_layer"):
392
+ module.set_lora_layer(None)
393
+
394
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
395
+ # pop `processor` from `self._modules`
396
+ if (
397
+ hasattr(self, "processor")
398
+ and isinstance(self.processor, torch.nn.Module)
399
+ and not isinstance(processor, torch.nn.Module)
400
+ ):
401
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
402
+ self._modules.pop("processor")
403
+
404
+ self.processor = processor
405
+
406
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
407
+ r"""
408
+ Get the attention processor in use.
409
+
410
+ Args:
411
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
412
+ Set to `True` to return the deprecated LoRA attention processor.
413
+
414
+ Returns:
415
+ "AttentionProcessor": The attention processor in use.
416
+ """
417
+ if not return_deprecated_lora:
418
+ return self.processor
419
+
420
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
421
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
422
+ # with PEFT is completed.
423
+ is_lora_activated = {
424
+ name: module.lora_layer is not None
425
+ for name, module in self.named_modules()
426
+ if hasattr(module, "lora_layer")
427
+ }
428
+
429
+ # 1. if no layer has a LoRA activated we can return the processor as usual
430
+ if not any(is_lora_activated.values()):
431
+ return self.processor
432
+
433
+ # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
434
+ is_lora_activated.pop("add_k_proj", None)
435
+ is_lora_activated.pop("add_v_proj", None)
436
+ # 2. else it is not posssible that only some layers have LoRA activated
437
+ if not all(is_lora_activated.values()):
438
+ raise ValueError(
439
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
440
+ )
441
+
442
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
443
+ non_lora_processor_cls_name = self.processor.__class__.__name__
444
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
445
+
446
+ hidden_size = self.inner_dim
447
+
448
+ # now create a LoRA attention processor from the LoRA layers
449
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
450
+ kwargs = {
451
+ "cross_attention_dim": self.cross_attention_dim,
452
+ "rank": self.to_q.lora_layer.rank,
453
+ "network_alpha": self.to_q.lora_layer.network_alpha,
454
+ "q_rank": self.to_q.lora_layer.rank,
455
+ "q_hidden_size": self.to_q.lora_layer.out_features,
456
+ "k_rank": self.to_k.lora_layer.rank,
457
+ "k_hidden_size": self.to_k.lora_layer.out_features,
458
+ "v_rank": self.to_v.lora_layer.rank,
459
+ "v_hidden_size": self.to_v.lora_layer.out_features,
460
+ "out_rank": self.to_out[0].lora_layer.rank,
461
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
462
+ }
463
+
464
+ if hasattr(self.processor, "attention_op"):
465
+ kwargs["attention_op"] = self.processor.attention_op
466
+
467
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
468
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
469
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
470
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
471
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
472
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
473
+ lora_processor = lora_processor_cls(
474
+ hidden_size,
475
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
476
+ rank=self.to_q.lora_layer.rank,
477
+ network_alpha=self.to_q.lora_layer.network_alpha,
478
+ )
479
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
480
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
481
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
482
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
483
+
484
+ # only save if used
485
+ if self.add_k_proj.lora_layer is not None:
486
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
487
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
488
+ else:
489
+ lora_processor.add_k_proj_lora = None
490
+ lora_processor.add_v_proj_lora = None
491
+ else:
492
+ raise ValueError(f"{lora_processor_cls} does not exist.")
493
+
494
+ return lora_processor
495
+
496
+ def forward(
497
+ self,
498
+ hidden_states: torch.FloatTensor,
499
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
500
+ attention_mask: Optional[torch.FloatTensor] = None,
501
+ **cross_attention_kwargs,
502
+ ) -> torch.Tensor:
503
+ r"""
504
+ The forward method of the `Attention` class.
505
+
506
+ Args:
507
+ hidden_states (`torch.Tensor`):
508
+ The hidden states of the query.
509
+ encoder_hidden_states (`torch.Tensor`, *optional*):
510
+ The hidden states of the encoder.
511
+ attention_mask (`torch.Tensor`, *optional*):
512
+ The attention mask to use. If `None`, no mask is applied.
513
+ **cross_attention_kwargs:
514
+ Additional keyword arguments to pass along to the cross attention.
515
+
516
+ Returns:
517
+ `torch.Tensor`: The output of the attention layer.
518
+ """
519
+ # The `Attention` class can call different attention processors / attention functions
520
+ # here we simply pass along all tensors to the selected processor class
521
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
522
+ return self.processor(
523
+ self,
524
+ hidden_states,
525
+ encoder_hidden_states=encoder_hidden_states,
526
+ attention_mask=attention_mask,
527
+ **cross_attention_kwargs,
528
+ )
529
+
530
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
531
+ r"""
532
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
533
+ is the number of heads initialized while constructing the `Attention` class.
534
+
535
+ Args:
536
+ tensor (`torch.Tensor`): The tensor to reshape.
537
+
538
+ Returns:
539
+ `torch.Tensor`: The reshaped tensor.
540
+ """
541
+ head_size = self.heads
542
+ batch_size, seq_len, dim = tensor.shape
543
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
544
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
545
+ return tensor
546
+
547
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
548
+ r"""
549
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
550
+ the number of heads initialized while constructing the `Attention` class.
551
+
552
+ Args:
553
+ tensor (`torch.Tensor`): The tensor to reshape.
554
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
555
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
556
+
557
+ Returns:
558
+ `torch.Tensor`: The reshaped tensor.
559
+ """
560
+ head_size = self.heads
561
+ batch_size, seq_len, dim = tensor.shape
562
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
563
+ tensor = tensor.permute(0, 2, 1, 3)
564
+
565
+ if out_dim == 3:
566
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
567
+
568
+ return tensor
569
+
570
+ def get_attention_scores(
571
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
572
+ ) -> torch.Tensor:
573
+ r"""
574
+ Compute the attention scores.
575
+
576
+ Args:
577
+ query (`torch.Tensor`): The query tensor.
578
+ key (`torch.Tensor`): The key tensor.
579
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
580
+
581
+ Returns:
582
+ `torch.Tensor`: The attention probabilities/scores.
583
+ """
584
+ dtype = query.dtype
585
+ if self.upcast_attention:
586
+ query = query.float()
587
+ key = key.float()
588
+
589
+ if attention_mask is None:
590
+ baddbmm_input = torch.empty(
591
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
592
+ )
593
+ beta = 0
594
+ else:
595
+ baddbmm_input = attention_mask
596
+ beta = 1
597
+
598
+ attention_scores = torch.baddbmm(
599
+ baddbmm_input,
600
+ query,
601
+ key.transpose(-1, -2),
602
+ beta=beta,
603
+ alpha=self.scale,
604
+ )
605
+ del baddbmm_input
606
+
607
+ if self.upcast_softmax:
608
+ attention_scores = attention_scores.float()
609
+
610
+ attention_probs = attention_scores.softmax(dim=-1)
611
+ del attention_scores
612
+
613
+ attention_probs = attention_probs.to(dtype)
614
+
615
+ return attention_probs
616
+
617
+ def prepare_attention_mask(
618
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
619
+ ) -> torch.Tensor:
620
+ r"""
621
+ Prepare the attention mask for the attention computation.
622
+
623
+ Args:
624
+ attention_mask (`torch.Tensor`):
625
+ The attention mask to prepare.
626
+ target_length (`int`):
627
+ The target length of the attention mask. This is the length of the attention mask after padding.
628
+ batch_size (`int`):
629
+ The batch size, which is used to repeat the attention mask.
630
+ out_dim (`int`, *optional*, defaults to `3`):
631
+ The output dimension of the attention mask. Can be either `3` or `4`.
632
+
633
+ Returns:
634
+ `torch.Tensor`: The prepared attention mask.
635
+ """
636
+ head_size = self.heads
637
+ if attention_mask is None:
638
+ return attention_mask
639
+
640
+ current_length: int = attention_mask.shape[-1]
641
+ if current_length != target_length:
642
+ if attention_mask.device.type == "mps":
643
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
644
+ # Instead, we can manually construct the padding tensor.
645
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
646
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
647
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
648
+ else:
649
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
650
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
651
+ # remaining_length: int = target_length - current_length
652
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
653
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
654
+
655
+ if out_dim == 3:
656
+ if attention_mask.shape[0] < batch_size * head_size:
657
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
658
+ elif out_dim == 4:
659
+ attention_mask = attention_mask.unsqueeze(1)
660
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
661
+
662
+ return attention_mask
663
+
664
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
665
+ r"""
666
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
667
+ `Attention` class.
668
+
669
+ Args:
670
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
671
+
672
+ Returns:
673
+ `torch.Tensor`: The normalized encoder hidden states.
674
+ """
675
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
676
+
677
+ if isinstance(self.norm_cross, nn.LayerNorm):
678
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
679
+ elif isinstance(self.norm_cross, nn.GroupNorm):
680
+ # Group norm norms along the channels dimension and expects
681
+ # input to be in the shape of (N, C, *). In this case, we want
682
+ # to norm along the hidden dimension, so we need to move
683
+ # (batch_size, sequence_length, hidden_size) ->
684
+ # (batch_size, hidden_size, sequence_length)
685
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
686
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
687
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
688
+ else:
689
+ assert False
690
+
691
+ return encoder_hidden_states
692
+
693
+
694
+ class AttnProcessor:
695
+ r"""
696
+ Default processor for performing attention-related computations.
697
+ """
698
+
699
+ def __call__(
700
+ self,
701
+ attn: Attention,
702
+ hidden_states: torch.FloatTensor,
703
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
704
+ attention_mask: Optional[torch.FloatTensor] = None,
705
+ temb: Optional[torch.FloatTensor] = None,
706
+ scale: float = 1.0,
707
+ ) -> torch.Tensor:
708
+ residual = hidden_states
709
+
710
+ args = () if USE_PEFT_BACKEND else (scale,)
711
+
712
+ if attn.spatial_norm is not None:
713
+ hidden_states = attn.spatial_norm(hidden_states, temb)
714
+
715
+ input_ndim = hidden_states.ndim
716
+
717
+ if input_ndim == 4:
718
+ batch_size, channel, height, width = hidden_states.shape
719
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
720
+
721
+ batch_size, sequence_length, _ = (
722
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
723
+ )
724
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
725
+
726
+ if attn.group_norm is not None:
727
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
728
+
729
+ query = attn.to_q(hidden_states, *args)
730
+
731
+ if encoder_hidden_states is None:
732
+ encoder_hidden_states = hidden_states
733
+ elif attn.norm_cross:
734
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
735
+
736
+ key = attn.to_k(encoder_hidden_states, *args)
737
+ value = attn.to_v(encoder_hidden_states, *args)
738
+
739
+ query = attn.head_to_batch_dim(query)
740
+ key = attn.head_to_batch_dim(key)
741
+ value = attn.head_to_batch_dim(value)
742
+
743
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
744
+ hidden_states = torch.bmm(attention_probs, value)
745
+ hidden_states = attn.batch_to_head_dim(hidden_states)
746
+
747
+ # linear proj
748
+ hidden_states = attn.to_out[0](hidden_states, *args)
749
+ # dropout
750
+ hidden_states = attn.to_out[1](hidden_states)
751
+
752
+ if input_ndim == 4:
753
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
754
+
755
+ if attn.residual_connection:
756
+ hidden_states = hidden_states + residual
757
+
758
+ hidden_states = hidden_states / attn.rescale_output_factor
759
+
760
+ return hidden_states
761
+
762
+
763
+ class CustomDiffusionAttnProcessor(nn.Module):
764
+ r"""
765
+ Processor for implementing attention for the Custom Diffusion method.
766
+
767
+ Args:
768
+ train_kv (`bool`, defaults to `True`):
769
+ Whether to newly train the key and value matrices corresponding to the text features.
770
+ train_q_out (`bool`, defaults to `True`):
771
+ Whether to newly train query matrices corresponding to the latent image features.
772
+ hidden_size (`int`, *optional*, defaults to `None`):
773
+ The hidden size of the attention layer.
774
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
775
+ The number of channels in the `encoder_hidden_states`.
776
+ out_bias (`bool`, defaults to `True`):
777
+ Whether to include the bias parameter in `train_q_out`.
778
+ dropout (`float`, *optional*, defaults to 0.0):
779
+ The dropout probability to use.
780
+ """
781
+
782
+ def __init__(
783
+ self,
784
+ train_kv: bool = True,
785
+ train_q_out: bool = True,
786
+ hidden_size: Optional[int] = None,
787
+ cross_attention_dim: Optional[int] = None,
788
+ out_bias: bool = True,
789
+ dropout: float = 0.0,
790
+ ):
791
+ super().__init__()
792
+ self.train_kv = train_kv
793
+ self.train_q_out = train_q_out
794
+
795
+ self.hidden_size = hidden_size
796
+ self.cross_attention_dim = cross_attention_dim
797
+
798
+ # `_custom_diffusion` id for easy serialization and loading.
799
+ if self.train_kv:
800
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
801
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
802
+ if self.train_q_out:
803
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
804
+ self.to_out_custom_diffusion = nn.ModuleList([])
805
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
806
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
807
+
808
+ def __call__(
809
+ self,
810
+ attn: Attention,
811
+ hidden_states: torch.FloatTensor,
812
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
813
+ attention_mask: Optional[torch.FloatTensor] = None,
814
+ ) -> torch.Tensor:
815
+ batch_size, sequence_length, _ = hidden_states.shape
816
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
817
+ if self.train_q_out:
818
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
819
+ else:
820
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
821
+
822
+ if encoder_hidden_states is None:
823
+ crossattn = False
824
+ encoder_hidden_states = hidden_states
825
+ else:
826
+ crossattn = True
827
+ if attn.norm_cross:
828
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
829
+
830
+ if self.train_kv:
831
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
832
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
833
+ key = key.to(attn.to_q.weight.dtype)
834
+ value = value.to(attn.to_q.weight.dtype)
835
+ else:
836
+ key = attn.to_k(encoder_hidden_states)
837
+ value = attn.to_v(encoder_hidden_states)
838
+
839
+ if crossattn:
840
+ detach = torch.ones_like(key)
841
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
842
+ key = detach * key + (1 - detach) * key.detach()
843
+ value = detach * value + (1 - detach) * value.detach()
844
+
845
+ query = attn.head_to_batch_dim(query)
846
+ key = attn.head_to_batch_dim(key)
847
+ value = attn.head_to_batch_dim(value)
848
+
849
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
850
+ hidden_states = torch.bmm(attention_probs, value)
851
+ hidden_states = attn.batch_to_head_dim(hidden_states)
852
+
853
+ if self.train_q_out:
854
+ # linear proj
855
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
856
+ # dropout
857
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
858
+ else:
859
+ # linear proj
860
+ hidden_states = attn.to_out[0](hidden_states)
861
+ # dropout
862
+ hidden_states = attn.to_out[1](hidden_states)
863
+
864
+ return hidden_states
865
+
866
+
867
+ class AttnAddedKVProcessor:
868
+ r"""
869
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
870
+ encoder.
871
+ """
872
+
873
+ def __call__(
874
+ self,
875
+ attn: Attention,
876
+ hidden_states: torch.FloatTensor,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ attention_mask: Optional[torch.FloatTensor] = None,
879
+ scale: float = 1.0,
880
+ ) -> torch.Tensor:
881
+ residual = hidden_states
882
+
883
+ args = () if USE_PEFT_BACKEND else (scale,)
884
+
885
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
886
+ batch_size, sequence_length, _ = hidden_states.shape
887
+
888
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
889
+
890
+ if encoder_hidden_states is None:
891
+ encoder_hidden_states = hidden_states
892
+ elif attn.norm_cross:
893
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
894
+
895
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
896
+
897
+ query = attn.to_q(hidden_states, *args)
898
+ query = attn.head_to_batch_dim(query)
899
+
900
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
901
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
902
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
903
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
904
+
905
+ if not attn.only_cross_attention:
906
+ key = attn.to_k(hidden_states, *args)
907
+ value = attn.to_v(hidden_states, *args)
908
+ key = attn.head_to_batch_dim(key)
909
+ value = attn.head_to_batch_dim(value)
910
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
911
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
912
+ else:
913
+ key = encoder_hidden_states_key_proj
914
+ value = encoder_hidden_states_value_proj
915
+
916
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
917
+ hidden_states = torch.bmm(attention_probs, value)
918
+ hidden_states = attn.batch_to_head_dim(hidden_states)
919
+
920
+ # linear proj
921
+ hidden_states = attn.to_out[0](hidden_states, *args)
922
+ # dropout
923
+ hidden_states = attn.to_out[1](hidden_states)
924
+
925
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
926
+ hidden_states = hidden_states + residual
927
+
928
+ return hidden_states
929
+
930
+
931
+ class AttnAddedKVProcessor2_0:
932
+ r"""
933
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
934
+ learnable key and value matrices for the text encoder.
935
+ """
936
+
937
+ def __init__(self):
938
+ if not hasattr(F, "scaled_dot_product_attention"):
939
+ raise ImportError(
940
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
941
+ )
942
+
943
+ def __call__(
944
+ self,
945
+ attn: Attention,
946
+ hidden_states: torch.FloatTensor,
947
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
948
+ attention_mask: Optional[torch.FloatTensor] = None,
949
+ scale: float = 1.0,
950
+ ) -> torch.Tensor:
951
+ residual = hidden_states
952
+
953
+ args = () if USE_PEFT_BACKEND else (scale,)
954
+
955
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
956
+ batch_size, sequence_length, _ = hidden_states.shape
957
+
958
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
959
+
960
+ if encoder_hidden_states is None:
961
+ encoder_hidden_states = hidden_states
962
+ elif attn.norm_cross:
963
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
964
+
965
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
966
+
967
+ query = attn.to_q(hidden_states, *args)
968
+ query = attn.head_to_batch_dim(query, out_dim=4)
969
+
970
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
971
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
972
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
973
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
974
+
975
+ if not attn.only_cross_attention:
976
+ key = attn.to_k(hidden_states, *args)
977
+ value = attn.to_v(hidden_states, *args)
978
+ key = attn.head_to_batch_dim(key, out_dim=4)
979
+ value = attn.head_to_batch_dim(value, out_dim=4)
980
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
981
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
982
+ else:
983
+ key = encoder_hidden_states_key_proj
984
+ value = encoder_hidden_states_value_proj
985
+
986
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
987
+ # TODO: add support for attn.scale when we move to Torch 2.1
988
+ hidden_states = F.scaled_dot_product_attention(
989
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
990
+ )
991
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
992
+
993
+ # linear proj
994
+ hidden_states = attn.to_out[0](hidden_states, *args)
995
+ # dropout
996
+ hidden_states = attn.to_out[1](hidden_states)
997
+
998
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
999
+ hidden_states = hidden_states + residual
1000
+
1001
+ return hidden_states
1002
+
1003
+
1004
+ class XFormersAttnAddedKVProcessor:
1005
+ r"""
1006
+ Processor for implementing memory efficient attention using xFormers.
1007
+
1008
+ Args:
1009
+ attention_op (`Callable`, *optional*, defaults to `None`):
1010
+ The base
1011
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1012
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1013
+ operator.
1014
+ """
1015
+
1016
+ def __init__(self, attention_op: Optional[Callable] = None):
1017
+ self.attention_op = attention_op
1018
+
1019
+ def __call__(
1020
+ self,
1021
+ attn: Attention,
1022
+ hidden_states: torch.FloatTensor,
1023
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1024
+ attention_mask: Optional[torch.FloatTensor] = None,
1025
+ ) -> torch.Tensor:
1026
+ residual = hidden_states
1027
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1028
+ batch_size, sequence_length, _ = hidden_states.shape
1029
+
1030
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1031
+
1032
+ if encoder_hidden_states is None:
1033
+ encoder_hidden_states = hidden_states
1034
+ elif attn.norm_cross:
1035
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1036
+
1037
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1038
+
1039
+ query = attn.to_q(hidden_states)
1040
+ query = attn.head_to_batch_dim(query)
1041
+
1042
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1043
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1044
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1045
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1046
+
1047
+ if not attn.only_cross_attention:
1048
+ key = attn.to_k(hidden_states)
1049
+ value = attn.to_v(hidden_states)
1050
+ key = attn.head_to_batch_dim(key)
1051
+ value = attn.head_to_batch_dim(value)
1052
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1053
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1054
+ else:
1055
+ key = encoder_hidden_states_key_proj
1056
+ value = encoder_hidden_states_value_proj
1057
+
1058
+ hidden_states = xformers.ops.memory_efficient_attention(
1059
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1060
+ )
1061
+ hidden_states = hidden_states.to(query.dtype)
1062
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1063
+
1064
+ # linear proj
1065
+ hidden_states = attn.to_out[0](hidden_states)
1066
+ # dropout
1067
+ hidden_states = attn.to_out[1](hidden_states)
1068
+
1069
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1070
+ hidden_states = hidden_states + residual
1071
+
1072
+ return hidden_states
1073
+
1074
+
1075
+ class XFormersAttnProcessor:
1076
+ r"""
1077
+ Processor for implementing memory efficient attention using xFormers.
1078
+
1079
+ Args:
1080
+ attention_op (`Callable`, *optional*, defaults to `None`):
1081
+ The base
1082
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1083
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1084
+ operator.
1085
+ """
1086
+
1087
+ def __init__(self, attention_op: Optional[Callable] = None):
1088
+ self.attention_op = attention_op
1089
+
1090
+ def __call__(
1091
+ self,
1092
+ attn: Attention,
1093
+ hidden_states: torch.FloatTensor,
1094
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1095
+ attention_mask: Optional[torch.FloatTensor] = None,
1096
+ temb: Optional[torch.FloatTensor] = None,
1097
+ scale: float = 1.0,
1098
+ ) -> torch.FloatTensor:
1099
+ residual = hidden_states
1100
+
1101
+ args = () if USE_PEFT_BACKEND else (scale,)
1102
+
1103
+ if attn.spatial_norm is not None:
1104
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1105
+
1106
+ input_ndim = hidden_states.ndim
1107
+
1108
+ if input_ndim == 4:
1109
+ batch_size, channel, height, width = hidden_states.shape
1110
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1111
+
1112
+ batch_size, key_tokens, _ = (
1113
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1114
+ )
1115
+
1116
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1117
+ if attention_mask is not None:
1118
+ # expand our mask's singleton query_tokens dimension:
1119
+ # [batch*heads, 1, key_tokens] ->
1120
+ # [batch*heads, query_tokens, key_tokens]
1121
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1122
+ # [batch*heads, query_tokens, key_tokens]
1123
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1124
+ _, query_tokens, _ = hidden_states.shape
1125
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1126
+
1127
+ if attn.group_norm is not None:
1128
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1129
+
1130
+ query = attn.to_q(hidden_states, *args)
1131
+
1132
+ if encoder_hidden_states is None:
1133
+ encoder_hidden_states = hidden_states
1134
+ elif attn.norm_cross:
1135
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1136
+
1137
+ key = attn.to_k(encoder_hidden_states, *args)
1138
+ value = attn.to_v(encoder_hidden_states, *args)
1139
+
1140
+ query = attn.head_to_batch_dim(query).contiguous()
1141
+ key = attn.head_to_batch_dim(key).contiguous()
1142
+ value = attn.head_to_batch_dim(value).contiguous()
1143
+
1144
+ hidden_states = xformers.ops.memory_efficient_attention(
1145
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1146
+ )
1147
+ hidden_states = hidden_states.to(query.dtype)
1148
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1149
+
1150
+ # linear proj
1151
+ hidden_states = attn.to_out[0](hidden_states, *args)
1152
+ # dropout
1153
+ hidden_states = attn.to_out[1](hidden_states)
1154
+
1155
+ if input_ndim == 4:
1156
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1157
+
1158
+ if attn.residual_connection:
1159
+ hidden_states = hidden_states + residual
1160
+
1161
+ hidden_states = hidden_states / attn.rescale_output_factor
1162
+
1163
+ return hidden_states
1164
+
1165
+
1166
+ class AttnProcessor2_0:
1167
+ r"""
1168
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1169
+ """
1170
+
1171
+ def __init__(self):
1172
+ if not hasattr(F, "scaled_dot_product_attention"):
1173
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1174
+
1175
+ def __call__(
1176
+ self,
1177
+ attn: Attention,
1178
+ hidden_states: torch.FloatTensor,
1179
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1180
+ attention_mask: Optional[torch.FloatTensor] = None,
1181
+ temb: Optional[torch.FloatTensor] = None,
1182
+ scale: float = 1.0,
1183
+ ) -> torch.FloatTensor:
1184
+ residual = hidden_states
1185
+
1186
+ args = () if USE_PEFT_BACKEND else (scale,)
1187
+
1188
+ if attn.spatial_norm is not None:
1189
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1190
+
1191
+ input_ndim = hidden_states.ndim
1192
+
1193
+ if input_ndim == 4:
1194
+ batch_size, channel, height, width = hidden_states.shape
1195
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1196
+
1197
+ batch_size, sequence_length, _ = (
1198
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1199
+ )
1200
+
1201
+ if attention_mask is not None:
1202
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1203
+ # scaled_dot_product_attention expects attention_mask shape to be
1204
+ # (batch, heads, source_length, target_length)
1205
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1206
+
1207
+ if attn.group_norm is not None:
1208
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1209
+
1210
+ args = () if USE_PEFT_BACKEND else (scale,)
1211
+ query = attn.to_q(hidden_states, *args)
1212
+
1213
+ if encoder_hidden_states is None:
1214
+ encoder_hidden_states = hidden_states
1215
+ elif attn.norm_cross:
1216
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1217
+
1218
+ key = attn.to_k(encoder_hidden_states, *args)
1219
+ value = attn.to_v(encoder_hidden_states, *args)
1220
+
1221
+ inner_dim = key.shape[-1]
1222
+ head_dim = inner_dim // attn.heads
1223
+
1224
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1225
+
1226
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1227
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1228
+
1229
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1230
+ # TODO: add support for attn.scale when we move to Torch 2.1
1231
+ hidden_states = F.scaled_dot_product_attention(
1232
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1233
+ )
1234
+
1235
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1236
+ hidden_states = hidden_states.to(query.dtype)
1237
+
1238
+ # linear proj
1239
+ hidden_states = attn.to_out[0](hidden_states, *args)
1240
+ # dropout
1241
+ hidden_states = attn.to_out[1](hidden_states)
1242
+
1243
+ if input_ndim == 4:
1244
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1245
+
1246
+ if attn.residual_connection:
1247
+ hidden_states = hidden_states + residual
1248
+
1249
+ hidden_states = hidden_states / attn.rescale_output_factor
1250
+
1251
+ return hidden_states
1252
+
1253
+
1254
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1255
+ r"""
1256
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1257
+
1258
+ Args:
1259
+ train_kv (`bool`, defaults to `True`):
1260
+ Whether to newly train the key and value matrices corresponding to the text features.
1261
+ train_q_out (`bool`, defaults to `True`):
1262
+ Whether to newly train query matrices corresponding to the latent image features.
1263
+ hidden_size (`int`, *optional*, defaults to `None`):
1264
+ The hidden size of the attention layer.
1265
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1266
+ The number of channels in the `encoder_hidden_states`.
1267
+ out_bias (`bool`, defaults to `True`):
1268
+ Whether to include the bias parameter in `train_q_out`.
1269
+ dropout (`float`, *optional*, defaults to 0.0):
1270
+ The dropout probability to use.
1271
+ attention_op (`Callable`, *optional*, defaults to `None`):
1272
+ The base
1273
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1274
+ as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
1275
+ """
1276
+
1277
+ def __init__(
1278
+ self,
1279
+ train_kv: bool = True,
1280
+ train_q_out: bool = False,
1281
+ hidden_size: Optional[int] = None,
1282
+ cross_attention_dim: Optional[int] = None,
1283
+ out_bias: bool = True,
1284
+ dropout: float = 0.0,
1285
+ attention_op: Optional[Callable] = None,
1286
+ ):
1287
+ super().__init__()
1288
+ self.train_kv = train_kv
1289
+ self.train_q_out = train_q_out
1290
+
1291
+ self.hidden_size = hidden_size
1292
+ self.cross_attention_dim = cross_attention_dim
1293
+ self.attention_op = attention_op
1294
+
1295
+ # `_custom_diffusion` id for easy serialization and loading.
1296
+ if self.train_kv:
1297
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1298
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1299
+ if self.train_q_out:
1300
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1301
+ self.to_out_custom_diffusion = nn.ModuleList([])
1302
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1303
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1304
+
1305
+ def __call__(
1306
+ self,
1307
+ attn: Attention,
1308
+ hidden_states: torch.FloatTensor,
1309
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1310
+ attention_mask: Optional[torch.FloatTensor] = None,
1311
+ ) -> torch.FloatTensor:
1312
+ batch_size, sequence_length, _ = (
1313
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1314
+ )
1315
+
1316
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1317
+
1318
+ if self.train_q_out:
1319
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
1320
+ else:
1321
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
1322
+
1323
+ if encoder_hidden_states is None:
1324
+ crossattn = False
1325
+ encoder_hidden_states = hidden_states
1326
+ else:
1327
+ crossattn = True
1328
+ if attn.norm_cross:
1329
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1330
+
1331
+ if self.train_kv:
1332
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1333
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1334
+ key = key.to(attn.to_q.weight.dtype)
1335
+ value = value.to(attn.to_q.weight.dtype)
1336
+ else:
1337
+ key = attn.to_k(encoder_hidden_states)
1338
+ value = attn.to_v(encoder_hidden_states)
1339
+
1340
+ if crossattn:
1341
+ detach = torch.ones_like(key)
1342
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1343
+ key = detach * key + (1 - detach) * key.detach()
1344
+ value = detach * value + (1 - detach) * value.detach()
1345
+
1346
+ query = attn.head_to_batch_dim(query).contiguous()
1347
+ key = attn.head_to_batch_dim(key).contiguous()
1348
+ value = attn.head_to_batch_dim(value).contiguous()
1349
+
1350
+ hidden_states = xformers.ops.memory_efficient_attention(
1351
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1352
+ )
1353
+ hidden_states = hidden_states.to(query.dtype)
1354
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1355
+
1356
+ if self.train_q_out:
1357
+ # linear proj
1358
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1359
+ # dropout
1360
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1361
+ else:
1362
+ # linear proj
1363
+ hidden_states = attn.to_out[0](hidden_states)
1364
+ # dropout
1365
+ hidden_states = attn.to_out[1](hidden_states)
1366
+
1367
+ return hidden_states
1368
+
1369
+
1370
+ class CustomDiffusionAttnProcessor2_0(nn.Module):
1371
+ r"""
1372
+ Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
1373
+ dot-product attention.
1374
+
1375
+ Args:
1376
+ train_kv (`bool`, defaults to `True`):
1377
+ Whether to newly train the key and value matrices corresponding to the text features.
1378
+ train_q_out (`bool`, defaults to `True`):
1379
+ Whether to newly train query matrices corresponding to the latent image features.
1380
+ hidden_size (`int`, *optional*, defaults to `None`):
1381
+ The hidden size of the attention layer.
1382
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1383
+ The number of channels in the `encoder_hidden_states`.
1384
+ out_bias (`bool`, defaults to `True`):
1385
+ Whether to include the bias parameter in `train_q_out`.
1386
+ dropout (`float`, *optional*, defaults to 0.0):
1387
+ The dropout probability to use.
1388
+ """
1389
+
1390
+ def __init__(
1391
+ self,
1392
+ train_kv: bool = True,
1393
+ train_q_out: bool = True,
1394
+ hidden_size: Optional[int] = None,
1395
+ cross_attention_dim: Optional[int] = None,
1396
+ out_bias: bool = True,
1397
+ dropout: float = 0.0,
1398
+ ):
1399
+ super().__init__()
1400
+ self.train_kv = train_kv
1401
+ self.train_q_out = train_q_out
1402
+
1403
+ self.hidden_size = hidden_size
1404
+ self.cross_attention_dim = cross_attention_dim
1405
+
1406
+ # `_custom_diffusion` id for easy serialization and loading.
1407
+ if self.train_kv:
1408
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1409
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1410
+ if self.train_q_out:
1411
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1412
+ self.to_out_custom_diffusion = nn.ModuleList([])
1413
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1414
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1415
+
1416
+ def __call__(
1417
+ self,
1418
+ attn: Attention,
1419
+ hidden_states: torch.FloatTensor,
1420
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1421
+ attention_mask: Optional[torch.FloatTensor] = None,
1422
+ ) -> torch.FloatTensor:
1423
+ batch_size, sequence_length, _ = hidden_states.shape
1424
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1425
+ if self.train_q_out:
1426
+ query = self.to_q_custom_diffusion(hidden_states)
1427
+ else:
1428
+ query = attn.to_q(hidden_states)
1429
+
1430
+ if encoder_hidden_states is None:
1431
+ crossattn = False
1432
+ encoder_hidden_states = hidden_states
1433
+ else:
1434
+ crossattn = True
1435
+ if attn.norm_cross:
1436
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1437
+
1438
+ if self.train_kv:
1439
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1440
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1441
+ key = key.to(attn.to_q.weight.dtype)
1442
+ value = value.to(attn.to_q.weight.dtype)
1443
+
1444
+ else:
1445
+ key = attn.to_k(encoder_hidden_states)
1446
+ value = attn.to_v(encoder_hidden_states)
1447
+
1448
+ if crossattn:
1449
+ detach = torch.ones_like(key)
1450
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1451
+ key = detach * key + (1 - detach) * key.detach()
1452
+ value = detach * value + (1 - detach) * value.detach()
1453
+
1454
+ inner_dim = hidden_states.shape[-1]
1455
+
1456
+ head_dim = inner_dim // attn.heads
1457
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1458
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1459
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1460
+
1461
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1462
+ # TODO: add support for attn.scale when we move to Torch 2.1
1463
+ hidden_states = F.scaled_dot_product_attention(
1464
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1465
+ )
1466
+
1467
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1468
+ hidden_states = hidden_states.to(query.dtype)
1469
+
1470
+ if self.train_q_out:
1471
+ # linear proj
1472
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1473
+ # dropout
1474
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1475
+ else:
1476
+ # linear proj
1477
+ hidden_states = attn.to_out[0](hidden_states)
1478
+ # dropout
1479
+ hidden_states = attn.to_out[1](hidden_states)
1480
+
1481
+ return hidden_states
1482
+
1483
+
1484
+ class SlicedAttnProcessor:
1485
+ r"""
1486
+ Processor for implementing sliced attention.
1487
+
1488
+ Args:
1489
+ slice_size (`int`, *optional*):
1490
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1491
+ `attention_head_dim` must be a multiple of the `slice_size`.
1492
+ """
1493
+
1494
+ def __init__(self, slice_size: int):
1495
+ self.slice_size = slice_size
1496
+
1497
+ def __call__(
1498
+ self,
1499
+ attn: Attention,
1500
+ hidden_states: torch.FloatTensor,
1501
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1502
+ attention_mask: Optional[torch.FloatTensor] = None,
1503
+ ) -> torch.FloatTensor:
1504
+ residual = hidden_states
1505
+
1506
+ input_ndim = hidden_states.ndim
1507
+
1508
+ if input_ndim == 4:
1509
+ batch_size, channel, height, width = hidden_states.shape
1510
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1511
+
1512
+ batch_size, sequence_length, _ = (
1513
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1514
+ )
1515
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1516
+
1517
+ if attn.group_norm is not None:
1518
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1519
+
1520
+ query = attn.to_q(hidden_states)
1521
+ dim = query.shape[-1]
1522
+ query = attn.head_to_batch_dim(query)
1523
+
1524
+ if encoder_hidden_states is None:
1525
+ encoder_hidden_states = hidden_states
1526
+ elif attn.norm_cross:
1527
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1528
+
1529
+ key = attn.to_k(encoder_hidden_states)
1530
+ value = attn.to_v(encoder_hidden_states)
1531
+ key = attn.head_to_batch_dim(key)
1532
+ value = attn.head_to_batch_dim(value)
1533
+
1534
+ batch_size_attention, query_tokens, _ = query.shape
1535
+ hidden_states = torch.zeros(
1536
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1537
+ )
1538
+
1539
+ for i in range(batch_size_attention // self.slice_size):
1540
+ start_idx = i * self.slice_size
1541
+ end_idx = (i + 1) * self.slice_size
1542
+
1543
+ query_slice = query[start_idx:end_idx]
1544
+ key_slice = key[start_idx:end_idx]
1545
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1546
+
1547
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1548
+
1549
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1550
+
1551
+ hidden_states[start_idx:end_idx] = attn_slice
1552
+
1553
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1554
+
1555
+ # linear proj
1556
+ hidden_states = attn.to_out[0](hidden_states)
1557
+ # dropout
1558
+ hidden_states = attn.to_out[1](hidden_states)
1559
+
1560
+ if input_ndim == 4:
1561
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1562
+
1563
+ if attn.residual_connection:
1564
+ hidden_states = hidden_states + residual
1565
+
1566
+ hidden_states = hidden_states / attn.rescale_output_factor
1567
+
1568
+ return hidden_states
1569
+
1570
+
1571
+ class SlicedAttnAddedKVProcessor:
1572
+ r"""
1573
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1574
+
1575
+ Args:
1576
+ slice_size (`int`, *optional*):
1577
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1578
+ `attention_head_dim` must be a multiple of the `slice_size`.
1579
+ """
1580
+
1581
+ def __init__(self, slice_size):
1582
+ self.slice_size = slice_size
1583
+
1584
+ def __call__(
1585
+ self,
1586
+ attn: "Attention",
1587
+ hidden_states: torch.FloatTensor,
1588
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1589
+ attention_mask: Optional[torch.FloatTensor] = None,
1590
+ temb: Optional[torch.FloatTensor] = None,
1591
+ ) -> torch.FloatTensor:
1592
+ residual = hidden_states
1593
+
1594
+ if attn.spatial_norm is not None:
1595
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1596
+
1597
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1598
+
1599
+ batch_size, sequence_length, _ = hidden_states.shape
1600
+
1601
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1602
+
1603
+ if encoder_hidden_states is None:
1604
+ encoder_hidden_states = hidden_states
1605
+ elif attn.norm_cross:
1606
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1607
+
1608
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1609
+
1610
+ query = attn.to_q(hidden_states)
1611
+ dim = query.shape[-1]
1612
+ query = attn.head_to_batch_dim(query)
1613
+
1614
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1615
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1616
+
1617
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1618
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1619
+
1620
+ if not attn.only_cross_attention:
1621
+ key = attn.to_k(hidden_states)
1622
+ value = attn.to_v(hidden_states)
1623
+ key = attn.head_to_batch_dim(key)
1624
+ value = attn.head_to_batch_dim(value)
1625
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1626
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1627
+ else:
1628
+ key = encoder_hidden_states_key_proj
1629
+ value = encoder_hidden_states_value_proj
1630
+
1631
+ batch_size_attention, query_tokens, _ = query.shape
1632
+ hidden_states = torch.zeros(
1633
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1634
+ )
1635
+
1636
+ for i in range(batch_size_attention // self.slice_size):
1637
+ start_idx = i * self.slice_size
1638
+ end_idx = (i + 1) * self.slice_size
1639
+
1640
+ query_slice = query[start_idx:end_idx]
1641
+ key_slice = key[start_idx:end_idx]
1642
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1643
+
1644
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1645
+
1646
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1647
+
1648
+ hidden_states[start_idx:end_idx] = attn_slice
1649
+
1650
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1651
+
1652
+ # linear proj
1653
+ hidden_states = attn.to_out[0](hidden_states)
1654
+ # dropout
1655
+ hidden_states = attn.to_out[1](hidden_states)
1656
+
1657
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1658
+ hidden_states = hidden_states + residual
1659
+
1660
+ return hidden_states
1661
+
1662
+
1663
+ class SpatialNorm(nn.Module):
1664
+ """
1665
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
1666
+
1667
+ Args:
1668
+ f_channels (`int`):
1669
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
1670
+ zq_channels (`int`):
1671
+ The number of channels for the quantized vector as described in the paper.
1672
+ """
1673
+
1674
+ def __init__(
1675
+ self,
1676
+ f_channels: int,
1677
+ zq_channels: int,
1678
+ ):
1679
+ super().__init__()
1680
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1681
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1682
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1683
+
1684
+ def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
1685
+ f_size = f.shape[-2:]
1686
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1687
+ norm_f = self.norm_layer(f)
1688
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1689
+ return new_f
1690
+
1691
+
1692
+ ## Deprecated
1693
+ class LoRAAttnProcessor(nn.Module):
1694
+ r"""
1695
+ Processor for implementing the LoRA attention mechanism.
1696
+
1697
+ Args:
1698
+ hidden_size (`int`, *optional*):
1699
+ The hidden size of the attention layer.
1700
+ cross_attention_dim (`int`, *optional*):
1701
+ The number of channels in the `encoder_hidden_states`.
1702
+ rank (`int`, defaults to 4):
1703
+ The dimension of the LoRA update matrices.
1704
+ network_alpha (`int`, *optional*):
1705
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1706
+ kwargs (`dict`):
1707
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1708
+ """
1709
+
1710
+ def __init__(
1711
+ self,
1712
+ hidden_size: int,
1713
+ cross_attention_dim: Optional[int] = None,
1714
+ rank: int = 4,
1715
+ network_alpha: Optional[int] = None,
1716
+ **kwargs,
1717
+ ):
1718
+ super().__init__()
1719
+
1720
+ self.hidden_size = hidden_size
1721
+ self.cross_attention_dim = cross_attention_dim
1722
+ self.rank = rank
1723
+
1724
+ q_rank = kwargs.pop("q_rank", None)
1725
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1726
+ q_rank = q_rank if q_rank is not None else rank
1727
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1728
+
1729
+ v_rank = kwargs.pop("v_rank", None)
1730
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1731
+ v_rank = v_rank if v_rank is not None else rank
1732
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1733
+
1734
+ out_rank = kwargs.pop("out_rank", None)
1735
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1736
+ out_rank = out_rank if out_rank is not None else rank
1737
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1738
+
1739
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1740
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1741
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1742
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1743
+
1744
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1745
+ self_cls_name = self.__class__.__name__
1746
+ deprecate(
1747
+ self_cls_name,
1748
+ "0.26.0",
1749
+ (
1750
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1751
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1752
+ " `LoraLoaderMixin.load_lora_weights`"
1753
+ ),
1754
+ )
1755
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1756
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1757
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1758
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1759
+
1760
+ attn._modules.pop("processor")
1761
+ attn.processor = AttnProcessor()
1762
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1763
+
1764
+
1765
+ class LoRAAttnProcessor2_0(nn.Module):
1766
+ r"""
1767
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1768
+ attention.
1769
+
1770
+ Args:
1771
+ hidden_size (`int`):
1772
+ The hidden size of the attention layer.
1773
+ cross_attention_dim (`int`, *optional*):
1774
+ The number of channels in the `encoder_hidden_states`.
1775
+ rank (`int`, defaults to 4):
1776
+ The dimension of the LoRA update matrices.
1777
+ network_alpha (`int`, *optional*):
1778
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1779
+ kwargs (`dict`):
1780
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1781
+ """
1782
+
1783
+ def __init__(
1784
+ self,
1785
+ hidden_size: int,
1786
+ cross_attention_dim: Optional[int] = None,
1787
+ rank: int = 4,
1788
+ network_alpha: Optional[int] = None,
1789
+ **kwargs,
1790
+ ):
1791
+ super().__init__()
1792
+ if not hasattr(F, "scaled_dot_product_attention"):
1793
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1794
+
1795
+ self.hidden_size = hidden_size
1796
+ self.cross_attention_dim = cross_attention_dim
1797
+ self.rank = rank
1798
+
1799
+ q_rank = kwargs.pop("q_rank", None)
1800
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1801
+ q_rank = q_rank if q_rank is not None else rank
1802
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1803
+
1804
+ v_rank = kwargs.pop("v_rank", None)
1805
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1806
+ v_rank = v_rank if v_rank is not None else rank
1807
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1808
+
1809
+ out_rank = kwargs.pop("out_rank", None)
1810
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1811
+ out_rank = out_rank if out_rank is not None else rank
1812
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1813
+
1814
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1815
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1816
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1817
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1818
+
1819
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1820
+ self_cls_name = self.__class__.__name__
1821
+ deprecate(
1822
+ self_cls_name,
1823
+ "0.26.0",
1824
+ (
1825
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1826
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1827
+ " `LoraLoaderMixin.load_lora_weights`"
1828
+ ),
1829
+ )
1830
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1831
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1832
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1833
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1834
+
1835
+ attn._modules.pop("processor")
1836
+ attn.processor = AttnProcessor2_0()
1837
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1838
+
1839
+
1840
+ class LoRAXFormersAttnProcessor(nn.Module):
1841
+ r"""
1842
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1843
+
1844
+ Args:
1845
+ hidden_size (`int`, *optional*):
1846
+ The hidden size of the attention layer.
1847
+ cross_attention_dim (`int`, *optional*):
1848
+ The number of channels in the `encoder_hidden_states`.
1849
+ rank (`int`, defaults to 4):
1850
+ The dimension of the LoRA update matrices.
1851
+ attention_op (`Callable`, *optional*, defaults to `None`):
1852
+ The base
1853
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1854
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1855
+ operator.
1856
+ network_alpha (`int`, *optional*):
1857
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1858
+ kwargs (`dict`):
1859
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1860
+ """
1861
+
1862
+ def __init__(
1863
+ self,
1864
+ hidden_size: int,
1865
+ cross_attention_dim: int,
1866
+ rank: int = 4,
1867
+ attention_op: Optional[Callable] = None,
1868
+ network_alpha: Optional[int] = None,
1869
+ **kwargs,
1870
+ ):
1871
+ super().__init__()
1872
+
1873
+ self.hidden_size = hidden_size
1874
+ self.cross_attention_dim = cross_attention_dim
1875
+ self.rank = rank
1876
+ self.attention_op = attention_op
1877
+
1878
+ q_rank = kwargs.pop("q_rank", None)
1879
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1880
+ q_rank = q_rank if q_rank is not None else rank
1881
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1882
+
1883
+ v_rank = kwargs.pop("v_rank", None)
1884
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1885
+ v_rank = v_rank if v_rank is not None else rank
1886
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1887
+
1888
+ out_rank = kwargs.pop("out_rank", None)
1889
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1890
+ out_rank = out_rank if out_rank is not None else rank
1891
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1892
+
1893
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1894
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1895
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1896
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1897
+
1898
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1899
+ self_cls_name = self.__class__.__name__
1900
+ deprecate(
1901
+ self_cls_name,
1902
+ "0.26.0",
1903
+ (
1904
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1905
+ "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1906
+ " `LoraLoaderMixin.load_lora_weights`"
1907
+ ),
1908
+ )
1909
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1910
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1911
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1912
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1913
+
1914
+ attn._modules.pop("processor")
1915
+ attn.processor = XFormersAttnProcessor()
1916
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1917
+
1918
+
1919
+ class LoRAAttnAddedKVProcessor(nn.Module):
1920
+ r"""
1921
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
1922
+ encoder.
1923
+
1924
+ Args:
1925
+ hidden_size (`int`, *optional*):
1926
+ The hidden size of the attention layer.
1927
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1928
+ The number of channels in the `encoder_hidden_states`.
1929
+ rank (`int`, defaults to 4):
1930
+ The dimension of the LoRA update matrices.
1931
+ network_alpha (`int`, *optional*):
1932
+ Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
1933
+ kwargs (`dict`):
1934
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1935
+ """
1936
+
1937
+ def __init__(
1938
+ self,
1939
+ hidden_size: int,
1940
+ cross_attention_dim: Optional[int] = None,
1941
+ rank: int = 4,
1942
+ network_alpha: Optional[int] = None,
1943
+ ):
1944
+ super().__init__()
1945
+
1946
+ self.hidden_size = hidden_size
1947
+ self.cross_attention_dim = cross_attention_dim
1948
+ self.rank = rank
1949
+
1950
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1951
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1952
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1953
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1954
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1955
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1956
+
1957
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1958
+ self_cls_name = self.__class__.__name__
1959
+ deprecate(
1960
+ self_cls_name,
1961
+ "0.26.0",
1962
+ (
1963
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1964
+ "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1965
+ " `LoraLoaderMixin.load_lora_weights`"
1966
+ ),
1967
+ )
1968
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1969
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1970
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1971
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1972
+
1973
+ attn._modules.pop("processor")
1974
+ attn.processor = AttnAddedKVProcessor()
1975
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1976
+
1977
+
1978
+ LORA_ATTENTION_PROCESSORS = (
1979
+ LoRAAttnProcessor,
1980
+ LoRAAttnProcessor2_0,
1981
+ LoRAXFormersAttnProcessor,
1982
+ LoRAAttnAddedKVProcessor,
1983
+ )
1984
+
1985
+ ADDED_KV_ATTENTION_PROCESSORS = (
1986
+ AttnAddedKVProcessor,
1987
+ SlicedAttnAddedKVProcessor,
1988
+ AttnAddedKVProcessor2_0,
1989
+ XFormersAttnAddedKVProcessor,
1990
+ LoRAAttnAddedKVProcessor,
1991
+ )
1992
+
1993
+ CROSS_ATTENTION_PROCESSORS = (
1994
+ AttnProcessor,
1995
+ AttnProcessor2_0,
1996
+ XFormersAttnProcessor,
1997
+ SlicedAttnProcessor,
1998
+ LoRAAttnProcessor,
1999
+ LoRAAttnProcessor2_0,
2000
+ LoRAXFormersAttnProcessor,
2001
+ )
2002
+
2003
+ AttentionProcessor = Union[
2004
+ AttnProcessor,
2005
+ AttnProcessor2_0,
2006
+ XFormersAttnProcessor,
2007
+ SlicedAttnProcessor,
2008
+ AttnAddedKVProcessor,
2009
+ SlicedAttnAddedKVProcessor,
2010
+ AttnAddedKVProcessor2_0,
2011
+ XFormersAttnAddedKVProcessor,
2012
+ CustomDiffusionAttnProcessor,
2013
+ CustomDiffusionXFormersAttnProcessor,
2014
+ CustomDiffusionAttnProcessor2_0,
2015
+ # deprecated
2016
+ LoRAAttnProcessor,
2017
+ LoRAAttnProcessor2_0,
2018
+ LoRAXFormersAttnProcessor,
2019
+ LoRAAttnAddedKVProcessor,
2020
+ ]
diffusers/models/autoencoder_asym_kl.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ..configuration_utils import ConfigMixin, register_to_config
20
+ from ..utils.accelerate_utils import apply_forward_hook
21
+ from .autoencoder_kl import AutoencoderKLOutput
22
+ from .modeling_utils import ModelMixin
23
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
24
+
25
+
26
+ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
27
+ r"""
28
+ Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
29
+ for encoding images into latents and decoding latent representations into images.
30
+
31
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
32
+ for all models (such as downloading or saving).
33
+
34
+ Parameters:
35
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
36
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
37
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
38
+ Tuple of downsample block types.
39
+ down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
40
+ Tuple of down block output channels.
41
+ layers_per_down_block (`int`, *optional*, defaults to `1`):
42
+ Number layers for down block.
43
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
44
+ Tuple of upsample block types.
45
+ up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
46
+ Tuple of up block output channels.
47
+ layers_per_up_block (`int`, *optional*, defaults to `1`):
48
+ Number layers for up block.
49
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
50
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
51
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
52
+ norm_num_groups (`int`, *optional*, defaults to `32`):
53
+ Number of groups to use for the first normalization layer in ResNet blocks.
54
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
55
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
56
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
57
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
58
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
59
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
60
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
61
+ """
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
69
+ down_block_out_channels: Tuple[int] = (64,),
70
+ layers_per_down_block: int = 1,
71
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
72
+ up_block_out_channels: Tuple[int] = (64,),
73
+ layers_per_up_block: int = 1,
74
+ act_fn: str = "silu",
75
+ latent_channels: int = 4,
76
+ norm_num_groups: int = 32,
77
+ sample_size: int = 32,
78
+ scaling_factor: float = 0.18215,
79
+ ) -> None:
80
+ super().__init__()
81
+
82
+ # pass init params to Encoder
83
+ self.encoder = Encoder(
84
+ in_channels=in_channels,
85
+ out_channels=latent_channels,
86
+ down_block_types=down_block_types,
87
+ block_out_channels=down_block_out_channels,
88
+ layers_per_block=layers_per_down_block,
89
+ act_fn=act_fn,
90
+ norm_num_groups=norm_num_groups,
91
+ double_z=True,
92
+ )
93
+
94
+ # pass init params to Decoder
95
+ self.decoder = MaskConditionDecoder(
96
+ in_channels=latent_channels,
97
+ out_channels=out_channels,
98
+ up_block_types=up_block_types,
99
+ block_out_channels=up_block_out_channels,
100
+ layers_per_block=layers_per_up_block,
101
+ act_fn=act_fn,
102
+ norm_num_groups=norm_num_groups,
103
+ )
104
+
105
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
106
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
107
+
108
+ self.use_slicing = False
109
+ self.use_tiling = False
110
+
111
+ @apply_forward_hook
112
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
113
+ h = self.encoder(x)
114
+ moments = self.quant_conv(h)
115
+ posterior = DiagonalGaussianDistribution(moments)
116
+
117
+ if not return_dict:
118
+ return (posterior,)
119
+
120
+ return AutoencoderKLOutput(latent_dist=posterior)
121
+
122
+ def _decode(
123
+ self,
124
+ z: torch.FloatTensor,
125
+ image: Optional[torch.FloatTensor] = None,
126
+ mask: Optional[torch.FloatTensor] = None,
127
+ return_dict: bool = True,
128
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
129
+ z = self.post_quant_conv(z)
130
+ dec = self.decoder(z, image, mask)
131
+
132
+ if not return_dict:
133
+ return (dec,)
134
+
135
+ return DecoderOutput(sample=dec)
136
+
137
+ @apply_forward_hook
138
+ def decode(
139
+ self,
140
+ z: torch.FloatTensor,
141
+ generator: Optional[torch.Generator] = None,
142
+ image: Optional[torch.FloatTensor] = None,
143
+ mask: Optional[torch.FloatTensor] = None,
144
+ return_dict: bool = True,
145
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
146
+ decoded = self._decode(z, image, mask).sample
147
+
148
+ if not return_dict:
149
+ return (decoded,)
150
+
151
+ return DecoderOutput(sample=decoded)
152
+
153
+ def forward(
154
+ self,
155
+ sample: torch.FloatTensor,
156
+ mask: Optional[torch.FloatTensor] = None,
157
+ sample_posterior: bool = False,
158
+ return_dict: bool = True,
159
+ generator: Optional[torch.Generator] = None,
160
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
161
+ r"""
162
+ Args:
163
+ sample (`torch.FloatTensor`): Input sample.
164
+ mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
165
+ sample_posterior (`bool`, *optional*, defaults to `False`):
166
+ Whether to sample from the posterior.
167
+ return_dict (`bool`, *optional*, defaults to `True`):
168
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
169
+ """
170
+ x = sample
171
+ posterior = self.encode(x).latent_dist
172
+ if sample_posterior:
173
+ z = posterior.sample(generator=generator)
174
+ else:
175
+ z = posterior.mode()
176
+ dec = self.decode(z, sample, mask).sample
177
+
178
+ if not return_dict:
179
+ return (dec,)
180
+
181
+ return DecoderOutput(sample=dec)
diffusers/models/autoencoder_kl.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..loaders import FromOriginalVAEMixin
22
+ from ..utils import BaseOutput
23
+ from ..utils.accelerate_utils import apply_forward_hook
24
+ from .attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from .modeling_utils import ModelMixin
32
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
33
+
34
+
35
+ @dataclass
36
+ class AutoencoderKLOutput(BaseOutput):
37
+ """
38
+ Output of AutoencoderKL encoding method.
39
+
40
+ Args:
41
+ latent_dist (`DiagonalGaussianDistribution`):
42
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
43
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
44
+ """
45
+
46
+ latent_dist: "DiagonalGaussianDistribution"
47
+
48
+
49
+ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
50
+ r"""
51
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
52
+
53
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
54
+ for all models (such as downloading or saving).
55
+
56
+ Parameters:
57
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
58
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
59
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
60
+ Tuple of downsample block types.
61
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
62
+ Tuple of upsample block types.
63
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
64
+ Tuple of block output channels.
65
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
66
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
67
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
68
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
69
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
70
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
71
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
72
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
73
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
74
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
75
+ force_upcast (`bool`, *optional*, default to `True`):
76
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
77
+ can be fine-tuned / trained to a lower range without loosing too much precision in which case
78
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
79
+ """
80
+
81
+ _supports_gradient_checkpointing = True
82
+
83
+ @register_to_config
84
+ def __init__(
85
+ self,
86
+ in_channels: int = 3,
87
+ out_channels: int = 3,
88
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
89
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
90
+ block_out_channels: Tuple[int] = (64,),
91
+ layers_per_block: int = 1,
92
+ act_fn: str = "silu",
93
+ latent_channels: int = 4,
94
+ norm_num_groups: int = 32,
95
+ sample_size: int = 32,
96
+ scaling_factor: float = 0.18215,
97
+ force_upcast: float = True,
98
+ ):
99
+ super().__init__()
100
+
101
+ # pass init params to Encoder
102
+ self.encoder = Encoder(
103
+ in_channels=in_channels,
104
+ out_channels=latent_channels,
105
+ down_block_types=down_block_types,
106
+ block_out_channels=block_out_channels,
107
+ layers_per_block=layers_per_block,
108
+ act_fn=act_fn,
109
+ norm_num_groups=norm_num_groups,
110
+ double_z=True,
111
+ )
112
+
113
+ # pass init params to Decoder
114
+ self.decoder = Decoder(
115
+ in_channels=latent_channels,
116
+ out_channels=out_channels,
117
+ up_block_types=up_block_types,
118
+ block_out_channels=block_out_channels,
119
+ layers_per_block=layers_per_block,
120
+ norm_num_groups=norm_num_groups,
121
+ act_fn=act_fn,
122
+ )
123
+
124
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
125
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
126
+
127
+ self.use_slicing = False
128
+ self.use_tiling = False
129
+
130
+ # only relevant if vae tiling is enabled
131
+ self.tile_sample_min_size = self.config.sample_size
132
+ sample_size = (
133
+ self.config.sample_size[0]
134
+ if isinstance(self.config.sample_size, (list, tuple))
135
+ else self.config.sample_size
136
+ )
137
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
138
+ self.tile_overlap_factor = 0.25
139
+
140
+ def _set_gradient_checkpointing(self, module, value=False):
141
+ if isinstance(module, (Encoder, Decoder)):
142
+ module.gradient_checkpointing = value
143
+
144
+ def enable_tiling(self, use_tiling: bool = True):
145
+ r"""
146
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
147
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
148
+ processing larger images.
149
+ """
150
+ self.use_tiling = use_tiling
151
+
152
+ def disable_tiling(self):
153
+ r"""
154
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
155
+ decoding in one step.
156
+ """
157
+ self.enable_tiling(False)
158
+
159
+ def enable_slicing(self):
160
+ r"""
161
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
162
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
163
+ """
164
+ self.use_slicing = True
165
+
166
+ def disable_slicing(self):
167
+ r"""
168
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
169
+ decoding in one step.
170
+ """
171
+ self.use_slicing = False
172
+
173
+ @property
174
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
175
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
176
+ r"""
177
+ Returns:
178
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
179
+ indexed by its weight name.
180
+ """
181
+ # set recursively
182
+ processors = {}
183
+
184
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
185
+ if hasattr(module, "get_processor"):
186
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
187
+
188
+ for sub_name, child in module.named_children():
189
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
190
+
191
+ return processors
192
+
193
+ for name, module in self.named_children():
194
+ fn_recursive_add_processors(name, module, processors)
195
+
196
+ return processors
197
+
198
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
199
+ def set_attn_processor(
200
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
201
+ ):
202
+ r"""
203
+ Sets the attention processor to use to compute attention.
204
+
205
+ Parameters:
206
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
207
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
208
+ for **all** `Attention` layers.
209
+
210
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
211
+ processor. This is strongly recommended when setting trainable attention processors.
212
+
213
+ """
214
+ count = len(self.attn_processors.keys())
215
+
216
+ if isinstance(processor, dict) and len(processor) != count:
217
+ raise ValueError(
218
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
219
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
220
+ )
221
+
222
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
223
+ if hasattr(module, "set_processor"):
224
+ if not isinstance(processor, dict):
225
+ module.set_processor(processor, _remove_lora=_remove_lora)
226
+ else:
227
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
228
+
229
+ for sub_name, child in module.named_children():
230
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
231
+
232
+ for name, module in self.named_children():
233
+ fn_recursive_attn_processor(name, module, processor)
234
+
235
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
236
+ def set_default_attn_processor(self):
237
+ """
238
+ Disables custom attention processors and sets the default attention implementation.
239
+ """
240
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
241
+ processor = AttnAddedKVProcessor()
242
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
243
+ processor = AttnProcessor()
244
+ else:
245
+ raise ValueError(
246
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
247
+ )
248
+
249
+ self.set_attn_processor(processor, _remove_lora=True)
250
+
251
+ @apply_forward_hook
252
+ def encode(
253
+ self, x: torch.FloatTensor, return_dict: bool = True
254
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
255
+ """
256
+ Encode a batch of images into latents.
257
+
258
+ Args:
259
+ x (`torch.FloatTensor`): Input batch of images.
260
+ return_dict (`bool`, *optional*, defaults to `True`):
261
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
262
+
263
+ Returns:
264
+ The latent representations of the encoded images. If `return_dict` is True, a
265
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
266
+ """
267
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
268
+ return self.tiled_encode(x, return_dict=return_dict)
269
+
270
+ if self.use_slicing and x.shape[0] > 1:
271
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
272
+ h = torch.cat(encoded_slices)
273
+ else:
274
+ h = self.encoder(x)
275
+
276
+ moments = self.quant_conv(h)
277
+ posterior = DiagonalGaussianDistribution(moments)
278
+
279
+ if not return_dict:
280
+ return (posterior,)
281
+
282
+ return AutoencoderKLOutput(latent_dist=posterior)
283
+
284
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
285
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
286
+ return self.tiled_decode(z, return_dict=return_dict)
287
+
288
+ z = self.post_quant_conv(z)
289
+ dec = self.decoder(z)
290
+
291
+ if not return_dict:
292
+ return (dec,)
293
+
294
+ return DecoderOutput(sample=dec)
295
+
296
+ @apply_forward_hook
297
+ def decode(
298
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
299
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
300
+ """
301
+ Decode a batch of images.
302
+
303
+ Args:
304
+ z (`torch.FloatTensor`): Input batch of latent vectors.
305
+ return_dict (`bool`, *optional*, defaults to `True`):
306
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
307
+
308
+ Returns:
309
+ [`~models.vae.DecoderOutput`] or `tuple`:
310
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
311
+ returned.
312
+
313
+ """
314
+ if self.use_slicing and z.shape[0] > 1:
315
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
316
+ decoded = torch.cat(decoded_slices)
317
+ else:
318
+ decoded = self._decode(z).sample
319
+
320
+ if not return_dict:
321
+ return (decoded,)
322
+
323
+ return DecoderOutput(sample=decoded)
324
+
325
+ def blend_v(self, a, b, blend_extent):
326
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
327
+ for y in range(blend_extent):
328
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
329
+ return b
330
+
331
+ def blend_h(self, a, b, blend_extent):
332
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
333
+ for x in range(blend_extent):
334
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
335
+ return b
336
+
337
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
338
+ r"""Encode a batch of images using a tiled encoder.
339
+
340
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
341
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
342
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
343
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
344
+ output, but they should be much less noticeable.
345
+
346
+ Args:
347
+ x (`torch.FloatTensor`): Input batch of images.
348
+ return_dict (`bool`, *optional*, defaults to `True`):
349
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
350
+
351
+ Returns:
352
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
353
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
354
+ `tuple` is returned.
355
+ """
356
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
357
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
358
+ row_limit = self.tile_latent_min_size - blend_extent
359
+
360
+ # Split the image into 512x512 tiles and encode them separately.
361
+ rows = []
362
+ for i in range(0, x.shape[2], overlap_size):
363
+ row = []
364
+ for j in range(0, x.shape[3], overlap_size):
365
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
366
+ tile = self.encoder(tile)
367
+ tile = self.quant_conv(tile)
368
+ row.append(tile)
369
+ rows.append(row)
370
+ result_rows = []
371
+ for i, row in enumerate(rows):
372
+ result_row = []
373
+ for j, tile in enumerate(row):
374
+ # blend the above tile and the left tile
375
+ # to the current tile and add the current tile to the result row
376
+ if i > 0:
377
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
378
+ if j > 0:
379
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
380
+ result_row.append(tile[:, :, :row_limit, :row_limit])
381
+ result_rows.append(torch.cat(result_row, dim=3))
382
+
383
+ moments = torch.cat(result_rows, dim=2)
384
+ posterior = DiagonalGaussianDistribution(moments)
385
+
386
+ if not return_dict:
387
+ return (posterior,)
388
+
389
+ return AutoencoderKLOutput(latent_dist=posterior)
390
+
391
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
392
+ r"""
393
+ Decode a batch of images using a tiled decoder.
394
+
395
+ Args:
396
+ z (`torch.FloatTensor`): Input batch of latent vectors.
397
+ return_dict (`bool`, *optional*, defaults to `True`):
398
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
399
+
400
+ Returns:
401
+ [`~models.vae.DecoderOutput`] or `tuple`:
402
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
403
+ returned.
404
+ """
405
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
406
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
407
+ row_limit = self.tile_sample_min_size - blend_extent
408
+
409
+ # Split z into overlapping 64x64 tiles and decode them separately.
410
+ # The tiles have an overlap to avoid seams between tiles.
411
+ rows = []
412
+ for i in range(0, z.shape[2], overlap_size):
413
+ row = []
414
+ for j in range(0, z.shape[3], overlap_size):
415
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
416
+ tile = self.post_quant_conv(tile)
417
+ decoded = self.decoder(tile)
418
+ row.append(decoded)
419
+ rows.append(row)
420
+ result_rows = []
421
+ for i, row in enumerate(rows):
422
+ result_row = []
423
+ for j, tile in enumerate(row):
424
+ # blend the above tile and the left tile
425
+ # to the current tile and add the current tile to the result row
426
+ if i > 0:
427
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
428
+ if j > 0:
429
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
430
+ result_row.append(tile[:, :, :row_limit, :row_limit])
431
+ result_rows.append(torch.cat(result_row, dim=3))
432
+
433
+ dec = torch.cat(result_rows, dim=2)
434
+ if not return_dict:
435
+ return (dec,)
436
+
437
+ return DecoderOutput(sample=dec)
438
+
439
+ def forward(
440
+ self,
441
+ sample: torch.FloatTensor,
442
+ sample_posterior: bool = False,
443
+ return_dict: bool = True,
444
+ generator: Optional[torch.Generator] = None,
445
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
446
+ r"""
447
+ Args:
448
+ sample (`torch.FloatTensor`): Input sample.
449
+ sample_posterior (`bool`, *optional*, defaults to `False`):
450
+ Whether to sample from the posterior.
451
+ return_dict (`bool`, *optional*, defaults to `True`):
452
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
453
+ """
454
+ x = sample
455
+ posterior = self.encode(x).latent_dist
456
+ if sample_posterior:
457
+ z = posterior.sample(generator=generator)
458
+ else:
459
+ z = posterior.mode()
460
+ dec = self.decode(z).sample
461
+
462
+ if not return_dict:
463
+ return (dec,)
464
+
465
+ return DecoderOutput(sample=dec)
diffusers/models/autoencoder_tiny.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput
23
+ from ..utils.accelerate_utils import apply_forward_hook
24
+ from .modeling_utils import ModelMixin
25
+ from .vae import DecoderOutput, DecoderTiny, EncoderTiny
26
+
27
+
28
+ @dataclass
29
+ class AutoencoderTinyOutput(BaseOutput):
30
+ """
31
+ Output of AutoencoderTiny encoding method.
32
+
33
+ Args:
34
+ latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
35
+
36
+ """
37
+
38
+ latents: torch.Tensor
39
+
40
+
41
+ class AutoencoderTiny(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
44
+
45
+ [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
46
+
47
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
48
+ all models (such as downloading or saving).
49
+
50
+ Parameters:
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
54
+ Tuple of integers representing the number of output channels for each encoder block. The length of the
55
+ tuple should be equal to the number of encoder blocks.
56
+ decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
57
+ Tuple of integers representing the number of output channels for each decoder block. The length of the
58
+ tuple should be equal to the number of decoder blocks.
59
+ act_fn (`str`, *optional*, defaults to `"relu"`):
60
+ Activation function to be used throughout the model.
61
+ latent_channels (`int`, *optional*, defaults to 4):
62
+ Number of channels in the latent representation. The latent space acts as a compressed representation of
63
+ the input image.
64
+ upsampling_scaling_factor (`int`, *optional*, defaults to 2):
65
+ Scaling factor for upsampling in the decoder. It determines the size of the output image during the
66
+ upsampling process.
67
+ num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
68
+ Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
69
+ length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
70
+ number of encoder blocks.
71
+ num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
72
+ Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
73
+ length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
74
+ number of decoder blocks.
75
+ latent_magnitude (`float`, *optional*, defaults to 3.0):
76
+ Magnitude of the latent representation. This parameter scales the latent representation values to control
77
+ the extent of information preservation.
78
+ latent_shift (float, *optional*, defaults to 0.5):
79
+ Shift applied to the latent representation. This parameter controls the center of the latent space.
80
+ scaling_factor (`float`, *optional*, defaults to 1.0):
81
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
82
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
83
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
84
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
85
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
86
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
87
+ however, no such scaling factor was used, hence the value of 1.0 as the default.
88
+ force_upcast (`bool`, *optional*, default to `False`):
89
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
90
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
91
+ `force_upcast` can be set to `False` (see this fp16-friendly
92
+ [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
93
+ """
94
+ _supports_gradient_checkpointing = True
95
+
96
+ @register_to_config
97
+ def __init__(
98
+ self,
99
+ in_channels=3,
100
+ out_channels=3,
101
+ encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
102
+ decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
103
+ act_fn: str = "relu",
104
+ latent_channels: int = 4,
105
+ upsampling_scaling_factor: int = 2,
106
+ num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
107
+ num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
108
+ latent_magnitude: int = 3,
109
+ latent_shift: float = 0.5,
110
+ force_upcast: float = False,
111
+ scaling_factor: float = 1.0,
112
+ ):
113
+ super().__init__()
114
+
115
+ if len(encoder_block_out_channels) != len(num_encoder_blocks):
116
+ raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
117
+ if len(decoder_block_out_channels) != len(num_decoder_blocks):
118
+ raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
119
+
120
+ self.encoder = EncoderTiny(
121
+ in_channels=in_channels,
122
+ out_channels=latent_channels,
123
+ num_blocks=num_encoder_blocks,
124
+ block_out_channels=encoder_block_out_channels,
125
+ act_fn=act_fn,
126
+ )
127
+
128
+ self.decoder = DecoderTiny(
129
+ in_channels=latent_channels,
130
+ out_channels=out_channels,
131
+ num_blocks=num_decoder_blocks,
132
+ block_out_channels=decoder_block_out_channels,
133
+ upsampling_scaling_factor=upsampling_scaling_factor,
134
+ act_fn=act_fn,
135
+ )
136
+
137
+ self.latent_magnitude = latent_magnitude
138
+ self.latent_shift = latent_shift
139
+ self.scaling_factor = scaling_factor
140
+
141
+ self.use_slicing = False
142
+ self.use_tiling = False
143
+
144
+ # only relevant if vae tiling is enabled
145
+ self.spatial_scale_factor = 2**out_channels
146
+ self.tile_overlap_factor = 0.125
147
+ self.tile_sample_min_size = 512
148
+ self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
149
+
150
+ def _set_gradient_checkpointing(self, module, value=False):
151
+ if isinstance(module, (EncoderTiny, DecoderTiny)):
152
+ module.gradient_checkpointing = value
153
+
154
+ def scale_latents(self, x):
155
+ """raw latents -> [0, 1]"""
156
+ return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
157
+
158
+ def unscale_latents(self, x):
159
+ """[0, 1] -> raw latents"""
160
+ return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
161
+
162
+ def enable_slicing(self):
163
+ r"""
164
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
165
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
166
+ """
167
+ self.use_slicing = True
168
+
169
+ def disable_slicing(self):
170
+ r"""
171
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
172
+ decoding in one step.
173
+ """
174
+ self.use_slicing = False
175
+
176
+ def enable_tiling(self, use_tiling: bool = True):
177
+ r"""
178
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
179
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
180
+ processing larger images.
181
+ """
182
+ self.use_tiling = use_tiling
183
+
184
+ def disable_tiling(self):
185
+ r"""
186
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
187
+ decoding in one step.
188
+ """
189
+ self.enable_tiling(False)
190
+
191
+ def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
192
+ r"""Encode a batch of images using a tiled encoder.
193
+
194
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
195
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
196
+ tiles overlap and are blended together to form a smooth output.
197
+
198
+ Args:
199
+ x (`torch.FloatTensor`): Input batch of images.
200
+ return_dict (`bool`, *optional*, defaults to `True`):
201
+ Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
202
+
203
+ Returns:
204
+ [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
205
+ If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
206
+ plain `tuple` is returned.
207
+ """
208
+ # scale of encoder output relative to input
209
+ sf = self.spatial_scale_factor
210
+ tile_size = self.tile_sample_min_size
211
+
212
+ # number of pixels to blend and to traverse between tile
213
+ blend_size = int(tile_size * self.tile_overlap_factor)
214
+ traverse_size = tile_size - blend_size
215
+
216
+ # tiles index (up/left)
217
+ ti = range(0, x.shape[-2], traverse_size)
218
+ tj = range(0, x.shape[-1], traverse_size)
219
+
220
+ # mask for blending
221
+ blend_masks = torch.stack(
222
+ torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
223
+ )
224
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
225
+
226
+ # output array
227
+ out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
228
+ for i in ti:
229
+ for j in tj:
230
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
231
+ # tile result
232
+ tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
233
+ tile = self.encoder(tile_in)
234
+ h, w = tile.shape[-2], tile.shape[-1]
235
+ # blend tile result into output
236
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
237
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
238
+ blend_mask = blend_mask_i * blend_mask_j
239
+ tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
240
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
241
+ return out
242
+
243
+ def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
244
+ r"""Encode a batch of images using a tiled encoder.
245
+
246
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
247
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
248
+ tiles overlap and are blended together to form a smooth output.
249
+
250
+ Args:
251
+ x (`torch.FloatTensor`): Input batch of images.
252
+ return_dict (`bool`, *optional*, defaults to `True`):
253
+ Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
254
+
255
+ Returns:
256
+ [`~models.vae.DecoderOutput`] or `tuple`:
257
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
258
+ returned.
259
+ """
260
+ # scale of decoder output relative to input
261
+ sf = self.spatial_scale_factor
262
+ tile_size = self.tile_latent_min_size
263
+
264
+ # number of pixels to blend and to traverse between tiles
265
+ blend_size = int(tile_size * self.tile_overlap_factor)
266
+ traverse_size = tile_size - blend_size
267
+
268
+ # tiles index (up/left)
269
+ ti = range(0, x.shape[-2], traverse_size)
270
+ tj = range(0, x.shape[-1], traverse_size)
271
+
272
+ # mask for blending
273
+ blend_masks = torch.stack(
274
+ torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
275
+ )
276
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
277
+
278
+ # output array
279
+ out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
280
+ for i in ti:
281
+ for j in tj:
282
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
283
+ # tile result
284
+ tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
285
+ tile = self.decoder(tile_in)
286
+ h, w = tile.shape[-2], tile.shape[-1]
287
+ # blend tile result into output
288
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
289
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
290
+ blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
291
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
292
+ return out
293
+
294
+ @apply_forward_hook
295
+ def encode(
296
+ self, x: torch.FloatTensor, return_dict: bool = True
297
+ ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
298
+ if self.use_slicing and x.shape[0] > 1:
299
+ output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x) for x_slice in x.split(1)]
300
+ output = torch.cat(output)
301
+ else:
302
+ output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
303
+
304
+ if not return_dict:
305
+ return (output,)
306
+
307
+ return AutoencoderTinyOutput(latents=output)
308
+
309
+ @apply_forward_hook
310
+ def decode(
311
+ self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
312
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
313
+ if self.use_slicing and x.shape[0] > 1:
314
+ output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
315
+ output = torch.cat(output)
316
+ else:
317
+ output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
318
+
319
+ if not return_dict:
320
+ return (output,)
321
+
322
+ return DecoderOutput(sample=output)
323
+
324
+ def forward(
325
+ self,
326
+ sample: torch.FloatTensor,
327
+ return_dict: bool = True,
328
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
329
+ r"""
330
+ Args:
331
+ sample (`torch.FloatTensor`): Input sample.
332
+ return_dict (`bool`, *optional*, defaults to `True`):
333
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
334
+ """
335
+ enc = self.encode(sample).latents
336
+
337
+ # scale latents to be in [0, 1], then quantize latents to a byte tensor,
338
+ # as if we were storing the latents in an RGBA uint8 image.
339
+ scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
340
+
341
+ # unquantize latents back into [0, 1], then unscale latents back to their original range,
342
+ # as if we were loading the latents from an RGBA uint8 image.
343
+ unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
344
+
345
+ dec = self.decode(unscaled_enc)
346
+
347
+ if not return_dict:
348
+ return (dec,)
349
+ return DecoderOutput(sample=dec)
diffusers/models/consistency_decoder_vae.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..schedulers import ConsistencyDecoderScheduler
23
+ from ..utils import BaseOutput
24
+ from ..utils.accelerate_utils import apply_forward_hook
25
+ from ..utils.torch_utils import randn_tensor
26
+ from .attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from .modeling_utils import ModelMixin
34
+ from .unet_2d import UNet2DModel
35
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
36
+
37
+
38
+ @dataclass
39
+ class ConsistencyDecoderVAEOutput(BaseOutput):
40
+ """
41
+ Output of encoding method.
42
+
43
+ Args:
44
+ latent_dist (`DiagonalGaussianDistribution`):
45
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
46
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
47
+ """
48
+
49
+ latent_dist: "DiagonalGaussianDistribution"
50
+
51
+
52
+ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
53
+ r"""
54
+ The consistency decoder used with DALL-E 3.
55
+
56
+ Examples:
57
+ ```py
58
+ >>> import torch
59
+ >>> from diffusers import DiffusionPipeline, ConsistencyDecoderVAE
60
+
61
+ >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=pipe.torch_dtype)
62
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
63
+ ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
64
+ ... ).to("cuda")
65
+
66
+ >>> pipe("horse", generator=torch.manual_seed(0)).images
67
+ ```
68
+ """
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ scaling_factor=0.18215,
74
+ latent_channels=4,
75
+ encoder_act_fn="silu",
76
+ encoder_block_out_channels=(128, 256, 512, 512),
77
+ encoder_double_z=True,
78
+ encoder_down_block_types=(
79
+ "DownEncoderBlock2D",
80
+ "DownEncoderBlock2D",
81
+ "DownEncoderBlock2D",
82
+ "DownEncoderBlock2D",
83
+ ),
84
+ encoder_in_channels=3,
85
+ encoder_layers_per_block=2,
86
+ encoder_norm_num_groups=32,
87
+ encoder_out_channels=4,
88
+ decoder_add_attention=False,
89
+ decoder_block_out_channels=(320, 640, 1024, 1024),
90
+ decoder_down_block_types=(
91
+ "ResnetDownsampleBlock2D",
92
+ "ResnetDownsampleBlock2D",
93
+ "ResnetDownsampleBlock2D",
94
+ "ResnetDownsampleBlock2D",
95
+ ),
96
+ decoder_downsample_padding=1,
97
+ decoder_in_channels=7,
98
+ decoder_layers_per_block=3,
99
+ decoder_norm_eps=1e-05,
100
+ decoder_norm_num_groups=32,
101
+ decoder_num_train_timesteps=1024,
102
+ decoder_out_channels=6,
103
+ decoder_resnet_time_scale_shift="scale_shift",
104
+ decoder_time_embedding_type="learned",
105
+ decoder_up_block_types=(
106
+ "ResnetUpsampleBlock2D",
107
+ "ResnetUpsampleBlock2D",
108
+ "ResnetUpsampleBlock2D",
109
+ "ResnetUpsampleBlock2D",
110
+ ),
111
+ ):
112
+ super().__init__()
113
+ self.encoder = Encoder(
114
+ act_fn=encoder_act_fn,
115
+ block_out_channels=encoder_block_out_channels,
116
+ double_z=encoder_double_z,
117
+ down_block_types=encoder_down_block_types,
118
+ in_channels=encoder_in_channels,
119
+ layers_per_block=encoder_layers_per_block,
120
+ norm_num_groups=encoder_norm_num_groups,
121
+ out_channels=encoder_out_channels,
122
+ )
123
+
124
+ self.decoder_unet = UNet2DModel(
125
+ add_attention=decoder_add_attention,
126
+ block_out_channels=decoder_block_out_channels,
127
+ down_block_types=decoder_down_block_types,
128
+ downsample_padding=decoder_downsample_padding,
129
+ in_channels=decoder_in_channels,
130
+ layers_per_block=decoder_layers_per_block,
131
+ norm_eps=decoder_norm_eps,
132
+ norm_num_groups=decoder_norm_num_groups,
133
+ num_train_timesteps=decoder_num_train_timesteps,
134
+ out_channels=decoder_out_channels,
135
+ resnet_time_scale_shift=decoder_resnet_time_scale_shift,
136
+ time_embedding_type=decoder_time_embedding_type,
137
+ up_block_types=decoder_up_block_types,
138
+ )
139
+ self.decoder_scheduler = ConsistencyDecoderScheduler()
140
+ self.register_to_config(block_out_channels=encoder_block_out_channels)
141
+ self.register_buffer(
142
+ "means",
143
+ torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
144
+ persistent=False,
145
+ )
146
+ self.register_buffer(
147
+ "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
148
+ )
149
+
150
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
151
+
152
+ self.use_slicing = False
153
+ self.use_tiling = False
154
+
155
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
156
+ def enable_tiling(self, use_tiling: bool = True):
157
+ r"""
158
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
159
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
160
+ processing larger images.
161
+ """
162
+ self.use_tiling = use_tiling
163
+
164
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
165
+ def disable_tiling(self):
166
+ r"""
167
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
168
+ decoding in one step.
169
+ """
170
+ self.enable_tiling(False)
171
+
172
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
173
+ def enable_slicing(self):
174
+ r"""
175
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
176
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
177
+ """
178
+ self.use_slicing = True
179
+
180
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
181
+ def disable_slicing(self):
182
+ r"""
183
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
184
+ decoding in one step.
185
+ """
186
+ self.use_slicing = False
187
+
188
+ @property
189
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
190
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
191
+ r"""
192
+ Returns:
193
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
194
+ indexed by its weight name.
195
+ """
196
+ # set recursively
197
+ processors = {}
198
+
199
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
200
+ if hasattr(module, "get_processor"):
201
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
202
+
203
+ for sub_name, child in module.named_children():
204
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
205
+
206
+ return processors
207
+
208
+ for name, module in self.named_children():
209
+ fn_recursive_add_processors(name, module, processors)
210
+
211
+ return processors
212
+
213
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
214
+ def set_attn_processor(
215
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
216
+ ):
217
+ r"""
218
+ Sets the attention processor to use to compute attention.
219
+
220
+ Parameters:
221
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
222
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
223
+ for **all** `Attention` layers.
224
+
225
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
226
+ processor. This is strongly recommended when setting trainable attention processors.
227
+
228
+ """
229
+ count = len(self.attn_processors.keys())
230
+
231
+ if isinstance(processor, dict) and len(processor) != count:
232
+ raise ValueError(
233
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
234
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
235
+ )
236
+
237
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
238
+ if hasattr(module, "set_processor"):
239
+ if not isinstance(processor, dict):
240
+ module.set_processor(processor, _remove_lora=_remove_lora)
241
+ else:
242
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
243
+
244
+ for sub_name, child in module.named_children():
245
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
246
+
247
+ for name, module in self.named_children():
248
+ fn_recursive_attn_processor(name, module, processor)
249
+
250
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
251
+ def set_default_attn_processor(self):
252
+ """
253
+ Disables custom attention processors and sets the default attention implementation.
254
+ """
255
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
256
+ processor = AttnAddedKVProcessor()
257
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
258
+ processor = AttnProcessor()
259
+ else:
260
+ raise ValueError(
261
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
262
+ )
263
+
264
+ self.set_attn_processor(processor, _remove_lora=True)
265
+
266
+ @apply_forward_hook
267
+ def encode(
268
+ self, x: torch.FloatTensor, return_dict: bool = True
269
+ ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
270
+ """
271
+ Encode a batch of images into latents.
272
+
273
+ Args:
274
+ x (`torch.FloatTensor`): Input batch of images.
275
+ return_dict (`bool`, *optional*, defaults to `True`):
276
+ Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
277
+ tuple.
278
+
279
+ Returns:
280
+ The latent representations of the encoded images. If `return_dict` is True, a
281
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
282
+ is returned.
283
+ """
284
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
285
+ return self.tiled_encode(x, return_dict=return_dict)
286
+
287
+ if self.use_slicing and x.shape[0] > 1:
288
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
289
+ h = torch.cat(encoded_slices)
290
+ else:
291
+ h = self.encoder(x)
292
+
293
+ moments = self.quant_conv(h)
294
+ posterior = DiagonalGaussianDistribution(moments)
295
+
296
+ if not return_dict:
297
+ return (posterior,)
298
+
299
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
300
+
301
+ @apply_forward_hook
302
+ def decode(
303
+ self,
304
+ z: torch.FloatTensor,
305
+ generator: Optional[torch.Generator] = None,
306
+ return_dict: bool = True,
307
+ num_inference_steps=2,
308
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
309
+ z = (z * self.config.scaling_factor - self.means) / self.stds
310
+
311
+ scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
312
+ z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
313
+
314
+ batch_size, _, height, width = z.shape
315
+
316
+ self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
317
+
318
+ x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
319
+ (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
320
+ )
321
+
322
+ for t in self.decoder_scheduler.timesteps:
323
+ model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
324
+ model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
325
+ prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
326
+ x_t = prev_sample
327
+
328
+ x_0 = x_t
329
+
330
+ if not return_dict:
331
+ return (x_0,)
332
+
333
+ return DecoderOutput(sample=x_0)
334
+
335
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
336
+ def blend_v(self, a, b, blend_extent):
337
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
338
+ for y in range(blend_extent):
339
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
340
+ return b
341
+
342
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
343
+ def blend_h(self, a, b, blend_extent):
344
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
345
+ for x in range(blend_extent):
346
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
347
+ return b
348
+
349
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
350
+ r"""Encode a batch of images using a tiled encoder.
351
+
352
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
353
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
354
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
355
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
356
+ output, but they should be much less noticeable.
357
+
358
+ Args:
359
+ x (`torch.FloatTensor`): Input batch of images.
360
+ return_dict (`bool`, *optional*, defaults to `True`):
361
+ Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
362
+ plain tuple.
363
+
364
+ Returns:
365
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
366
+ If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
367
+ otherwise a plain `tuple` is returned.
368
+ """
369
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
370
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
371
+ row_limit = self.tile_latent_min_size - blend_extent
372
+
373
+ # Split the image into 512x512 tiles and encode them separately.
374
+ rows = []
375
+ for i in range(0, x.shape[2], overlap_size):
376
+ row = []
377
+ for j in range(0, x.shape[3], overlap_size):
378
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
379
+ tile = self.encoder(tile)
380
+ tile = self.quant_conv(tile)
381
+ row.append(tile)
382
+ rows.append(row)
383
+ result_rows = []
384
+ for i, row in enumerate(rows):
385
+ result_row = []
386
+ for j, tile in enumerate(row):
387
+ # blend the above tile and the left tile
388
+ # to the current tile and add the current tile to the result row
389
+ if i > 0:
390
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
391
+ if j > 0:
392
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
393
+ result_row.append(tile[:, :, :row_limit, :row_limit])
394
+ result_rows.append(torch.cat(result_row, dim=3))
395
+
396
+ moments = torch.cat(result_rows, dim=2)
397
+ posterior = DiagonalGaussianDistribution(moments)
398
+
399
+ if not return_dict:
400
+ return (posterior,)
401
+
402
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
403
+
404
+ def forward(
405
+ self,
406
+ sample: torch.FloatTensor,
407
+ sample_posterior: bool = False,
408
+ return_dict: bool = True,
409
+ generator: Optional[torch.Generator] = None,
410
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
411
+ r"""
412
+ Args:
413
+ sample (`torch.FloatTensor`): Input sample.
414
+ sample_posterior (`bool`, *optional*, defaults to `False`):
415
+ Whether to sample from the posterior.
416
+ return_dict (`bool`, *optional*, defaults to `True`):
417
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
418
+ """
419
+ x = sample
420
+ posterior = self.encode(x).latent_dist
421
+ if sample_posterior:
422
+ z = posterior.sample(generator=generator)
423
+ else:
424
+ z = posterior.mode()
425
+ dec = self.decode(z, generator=generator).sample
426
+
427
+ if not return_dict:
428
+ return (dec,)
429
+
430
+ return DecoderOutput(sample=dec)
diffusers/models/controlnet.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders import FromOriginalControlnetMixin
23
+ from ..utils import BaseOutput, logging
24
+ from .attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from .modeling_utils import ModelMixin
33
+ from .unet_2d_blocks import (
34
+ CrossAttnDownBlock2D,
35
+ DownBlock2D,
36
+ UNetMidBlock2DCrossAttn,
37
+ get_down_block,
38
+ )
39
+ from .unet_2d_condition import UNet2DConditionModel
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ @dataclass
46
+ class ControlNetOutput(BaseOutput):
47
+ """
48
+ The output of [`ControlNetModel`].
49
+
50
+ Args:
51
+ down_block_res_samples (`tuple[torch.Tensor]`):
52
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
53
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
54
+ used to condition the original UNet's downsampling activations.
55
+ mid_down_block_re_sample (`torch.Tensor`):
56
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
57
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
58
+ Output can be used to condition the original UNet's middle block activation.
59
+ """
60
+
61
+ down_block_res_samples: Tuple[torch.Tensor]
62
+ mid_block_res_sample: torch.Tensor
63
+
64
+
65
+ class ControlNetConditioningEmbedding(nn.Module):
66
+ """
67
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
68
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
69
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
70
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
71
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
72
+ model) to encode image-space conditions ... into feature maps ..."
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ conditioning_embedding_channels: int,
78
+ conditioning_channels: int = 3,
79
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
80
+ ):
81
+ super().__init__()
82
+
83
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
84
+
85
+ self.blocks = nn.ModuleList([])
86
+
87
+ for i in range(len(block_out_channels) - 1):
88
+ channel_in = block_out_channels[i]
89
+ channel_out = block_out_channels[i + 1]
90
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
91
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
92
+
93
+ self.conv_out = zero_module(
94
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
95
+ )
96
+
97
+ def forward(self, conditioning):
98
+ embedding = self.conv_in(conditioning)
99
+ embedding = F.silu(embedding)
100
+
101
+ for block in self.blocks:
102
+ embedding = block(embedding)
103
+ embedding = F.silu(embedding)
104
+
105
+ embedding = self.conv_out(embedding)
106
+
107
+ return embedding
108
+
109
+
110
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
111
+ """
112
+ A ControlNet model.
113
+
114
+ Args:
115
+ in_channels (`int`, defaults to 4):
116
+ The number of channels in the input sample.
117
+ flip_sin_to_cos (`bool`, defaults to `True`):
118
+ Whether to flip the sin to cos in the time embedding.
119
+ freq_shift (`int`, defaults to 0):
120
+ The frequency shift to apply to the time embedding.
121
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
122
+ The tuple of downsample blocks to use.
123
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
124
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
125
+ The tuple of output channels for each block.
126
+ layers_per_block (`int`, defaults to 2):
127
+ The number of layers per block.
128
+ downsample_padding (`int`, defaults to 1):
129
+ The padding to use for the downsampling convolution.
130
+ mid_block_scale_factor (`float`, defaults to 1):
131
+ The scale factor to use for the mid block.
132
+ act_fn (`str`, defaults to "silu"):
133
+ The activation function to use.
134
+ norm_num_groups (`int`, *optional*, defaults to 32):
135
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
136
+ in post-processing.
137
+ norm_eps (`float`, defaults to 1e-5):
138
+ The epsilon to use for the normalization.
139
+ cross_attention_dim (`int`, defaults to 1280):
140
+ The dimension of the cross attention features.
141
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
142
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
143
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
144
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
145
+ encoder_hid_dim (`int`, *optional*, defaults to None):
146
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
147
+ dimension to `cross_attention_dim`.
148
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
149
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
150
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
151
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
152
+ The dimension of the attention heads.
153
+ use_linear_projection (`bool`, defaults to `False`):
154
+ class_embed_type (`str`, *optional*, defaults to `None`):
155
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
156
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
157
+ addition_embed_type (`str`, *optional*, defaults to `None`):
158
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
159
+ "text". "text" will use the `TextTimeEmbedding` layer.
160
+ num_class_embeds (`int`, *optional*, defaults to 0):
161
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
162
+ class conditioning with `class_embed_type` equal to `None`.
163
+ upcast_attention (`bool`, defaults to `False`):
164
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
165
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
166
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
167
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
168
+ `class_embed_type="projection"`.
169
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
170
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
171
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
172
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
173
+ global_pool_conditions (`bool`, defaults to `False`):
174
+ """
175
+
176
+ _supports_gradient_checkpointing = True
177
+
178
+ @register_to_config
179
+ def __init__(
180
+ self,
181
+ in_channels: int = 4,
182
+ conditioning_channels: int = 3,
183
+ flip_sin_to_cos: bool = True,
184
+ freq_shift: int = 0,
185
+ down_block_types: Tuple[str] = (
186
+ "CrossAttnDownBlock2D",
187
+ "CrossAttnDownBlock2D",
188
+ "CrossAttnDownBlock2D",
189
+ "DownBlock2D",
190
+ ),
191
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
192
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
193
+ layers_per_block: int = 2,
194
+ downsample_padding: int = 1,
195
+ mid_block_scale_factor: float = 1,
196
+ act_fn: str = "silu",
197
+ norm_num_groups: Optional[int] = 32,
198
+ norm_eps: float = 1e-5,
199
+ cross_attention_dim: int = 1280,
200
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
201
+ encoder_hid_dim: Optional[int] = None,
202
+ encoder_hid_dim_type: Optional[str] = None,
203
+ attention_head_dim: Union[int, Tuple[int]] = 8,
204
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
205
+ use_linear_projection: bool = False,
206
+ class_embed_type: Optional[str] = None,
207
+ addition_embed_type: Optional[str] = None,
208
+ addition_time_embed_dim: Optional[int] = None,
209
+ num_class_embeds: Optional[int] = None,
210
+ upcast_attention: bool = False,
211
+ resnet_time_scale_shift: str = "default",
212
+ projection_class_embeddings_input_dim: Optional[int] = None,
213
+ controlnet_conditioning_channel_order: str = "rgb",
214
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
215
+ global_pool_conditions: bool = False,
216
+ addition_embed_type_num_heads=64,
217
+ ):
218
+ super().__init__()
219
+
220
+ # If `num_attention_heads` is not defined (which is the case for most models)
221
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
222
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
223
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
224
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
225
+ # which is why we correct for the naming here.
226
+ num_attention_heads = num_attention_heads or attention_head_dim
227
+
228
+ # Check inputs
229
+ if len(block_out_channels) != len(down_block_types):
230
+ raise ValueError(
231
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
232
+ )
233
+
234
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if isinstance(transformer_layers_per_block, int):
245
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
246
+
247
+ # input
248
+ conv_in_kernel = 3
249
+ conv_in_padding = (conv_in_kernel - 1) // 2
250
+ self.conv_in = nn.Conv2d(
251
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
252
+ )
253
+
254
+ # time
255
+ time_embed_dim = block_out_channels[0] * 4
256
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
257
+ timestep_input_dim = block_out_channels[0]
258
+ self.time_embedding = TimestepEmbedding(
259
+ timestep_input_dim,
260
+ time_embed_dim,
261
+ act_fn=act_fn,
262
+ )
263
+
264
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
265
+ encoder_hid_dim_type = "text_proj"
266
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
267
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
268
+
269
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
270
+ raise ValueError(
271
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
272
+ )
273
+
274
+ if encoder_hid_dim_type == "text_proj":
275
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
276
+ elif encoder_hid_dim_type == "text_image_proj":
277
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
278
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
279
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
280
+ self.encoder_hid_proj = TextImageProjection(
281
+ text_embed_dim=encoder_hid_dim,
282
+ image_embed_dim=cross_attention_dim,
283
+ cross_attention_dim=cross_attention_dim,
284
+ )
285
+
286
+ elif encoder_hid_dim_type is not None:
287
+ raise ValueError(
288
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
289
+ )
290
+ else:
291
+ self.encoder_hid_proj = None
292
+
293
+ # class embedding
294
+ if class_embed_type is None and num_class_embeds is not None:
295
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
296
+ elif class_embed_type == "timestep":
297
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
298
+ elif class_embed_type == "identity":
299
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
300
+ elif class_embed_type == "projection":
301
+ if projection_class_embeddings_input_dim is None:
302
+ raise ValueError(
303
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
304
+ )
305
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
306
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
307
+ # 2. it projects from an arbitrary input dimension.
308
+ #
309
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
310
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
311
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
312
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
313
+ else:
314
+ self.class_embedding = None
315
+
316
+ if addition_embed_type == "text":
317
+ if encoder_hid_dim is not None:
318
+ text_time_embedding_from_dim = encoder_hid_dim
319
+ else:
320
+ text_time_embedding_from_dim = cross_attention_dim
321
+
322
+ self.add_embedding = TextTimeEmbedding(
323
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
324
+ )
325
+ elif addition_embed_type == "text_image":
326
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
327
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
328
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
329
+ self.add_embedding = TextImageTimeEmbedding(
330
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
331
+ )
332
+ elif addition_embed_type == "text_time":
333
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
334
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
335
+
336
+ elif addition_embed_type is not None:
337
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
338
+
339
+ # control net conditioning embedding
340
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
341
+ conditioning_embedding_channels=block_out_channels[0],
342
+ block_out_channels=conditioning_embedding_out_channels,
343
+ conditioning_channels=conditioning_channels,
344
+ )
345
+
346
+ self.down_blocks = nn.ModuleList([])
347
+ self.controlnet_down_blocks = nn.ModuleList([])
348
+
349
+ if isinstance(only_cross_attention, bool):
350
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
351
+
352
+ if isinstance(attention_head_dim, int):
353
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
354
+
355
+ if isinstance(num_attention_heads, int):
356
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
357
+
358
+ # down
359
+ output_channel = block_out_channels[0]
360
+
361
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
362
+ controlnet_block = zero_module(controlnet_block)
363
+ self.controlnet_down_blocks.append(controlnet_block)
364
+
365
+ for i, down_block_type in enumerate(down_block_types):
366
+ input_channel = output_channel
367
+ output_channel = block_out_channels[i]
368
+ is_final_block = i == len(block_out_channels) - 1
369
+
370
+ down_block = get_down_block(
371
+ down_block_type,
372
+ num_layers=layers_per_block,
373
+ transformer_layers_per_block=transformer_layers_per_block[i],
374
+ in_channels=input_channel,
375
+ out_channels=output_channel,
376
+ temb_channels=time_embed_dim,
377
+ add_downsample=not is_final_block,
378
+ resnet_eps=norm_eps,
379
+ resnet_act_fn=act_fn,
380
+ resnet_groups=norm_num_groups,
381
+ cross_attention_dim=cross_attention_dim,
382
+ num_attention_heads=num_attention_heads[i],
383
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
384
+ downsample_padding=downsample_padding,
385
+ use_linear_projection=use_linear_projection,
386
+ only_cross_attention=only_cross_attention[i],
387
+ upcast_attention=upcast_attention,
388
+ resnet_time_scale_shift=resnet_time_scale_shift,
389
+ )
390
+ self.down_blocks.append(down_block)
391
+
392
+ for _ in range(layers_per_block):
393
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
394
+ controlnet_block = zero_module(controlnet_block)
395
+ self.controlnet_down_blocks.append(controlnet_block)
396
+
397
+ if not is_final_block:
398
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
399
+ controlnet_block = zero_module(controlnet_block)
400
+ self.controlnet_down_blocks.append(controlnet_block)
401
+
402
+ # mid
403
+ mid_block_channel = block_out_channels[-1]
404
+
405
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
406
+ controlnet_block = zero_module(controlnet_block)
407
+ self.controlnet_mid_block = controlnet_block
408
+
409
+ self.mid_block = UNetMidBlock2DCrossAttn(
410
+ transformer_layers_per_block=transformer_layers_per_block[-1],
411
+ in_channels=mid_block_channel,
412
+ temb_channels=time_embed_dim,
413
+ resnet_eps=norm_eps,
414
+ resnet_act_fn=act_fn,
415
+ output_scale_factor=mid_block_scale_factor,
416
+ resnet_time_scale_shift=resnet_time_scale_shift,
417
+ cross_attention_dim=cross_attention_dim,
418
+ num_attention_heads=num_attention_heads[-1],
419
+ resnet_groups=norm_num_groups,
420
+ use_linear_projection=use_linear_projection,
421
+ upcast_attention=upcast_attention,
422
+ )
423
+
424
+ @classmethod
425
+ def from_unet(
426
+ cls,
427
+ unet: UNet2DConditionModel,
428
+ controlnet_conditioning_channel_order: str = "rgb",
429
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
430
+ load_weights_from_unet: bool = True,
431
+ ):
432
+ r"""
433
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
434
+
435
+ Parameters:
436
+ unet (`UNet2DConditionModel`):
437
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
438
+ where applicable.
439
+ """
440
+ transformer_layers_per_block = (
441
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
442
+ )
443
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
444
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
445
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
446
+ addition_time_embed_dim = (
447
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
448
+ )
449
+
450
+ controlnet = cls(
451
+ encoder_hid_dim=encoder_hid_dim,
452
+ encoder_hid_dim_type=encoder_hid_dim_type,
453
+ addition_embed_type=addition_embed_type,
454
+ addition_time_embed_dim=addition_time_embed_dim,
455
+ transformer_layers_per_block=transformer_layers_per_block,
456
+ in_channels=unet.config.in_channels,
457
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
458
+ freq_shift=unet.config.freq_shift,
459
+ down_block_types=unet.config.down_block_types,
460
+ only_cross_attention=unet.config.only_cross_attention,
461
+ block_out_channels=unet.config.block_out_channels,
462
+ layers_per_block=unet.config.layers_per_block,
463
+ downsample_padding=unet.config.downsample_padding,
464
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
465
+ act_fn=unet.config.act_fn,
466
+ norm_num_groups=unet.config.norm_num_groups,
467
+ norm_eps=unet.config.norm_eps,
468
+ cross_attention_dim=unet.config.cross_attention_dim,
469
+ attention_head_dim=unet.config.attention_head_dim,
470
+ num_attention_heads=unet.config.num_attention_heads,
471
+ use_linear_projection=unet.config.use_linear_projection,
472
+ class_embed_type=unet.config.class_embed_type,
473
+ num_class_embeds=unet.config.num_class_embeds,
474
+ upcast_attention=unet.config.upcast_attention,
475
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
476
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
477
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
478
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
479
+ )
480
+
481
+ if load_weights_from_unet:
482
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
483
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
484
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
485
+
486
+ if controlnet.class_embedding:
487
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
488
+
489
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
490
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
491
+
492
+ return controlnet
493
+
494
+ @property
495
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
496
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
497
+ r"""
498
+ Returns:
499
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
500
+ indexed by its weight name.
501
+ """
502
+ # set recursively
503
+ processors = {}
504
+
505
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
506
+ if hasattr(module, "get_processor"):
507
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
508
+
509
+ for sub_name, child in module.named_children():
510
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
511
+
512
+ return processors
513
+
514
+ for name, module in self.named_children():
515
+ fn_recursive_add_processors(name, module, processors)
516
+
517
+ return processors
518
+
519
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
520
+ def set_attn_processor(
521
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
522
+ ):
523
+ r"""
524
+ Sets the attention processor to use to compute attention.
525
+
526
+ Parameters:
527
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
528
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
529
+ for **all** `Attention` layers.
530
+
531
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
532
+ processor. This is strongly recommended when setting trainable attention processors.
533
+
534
+ """
535
+ count = len(self.attn_processors.keys())
536
+
537
+ if isinstance(processor, dict) and len(processor) != count:
538
+ raise ValueError(
539
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
540
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
541
+ )
542
+
543
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
544
+ if hasattr(module, "set_processor"):
545
+ if not isinstance(processor, dict):
546
+ module.set_processor(processor, _remove_lora=_remove_lora)
547
+ else:
548
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
549
+
550
+ for sub_name, child in module.named_children():
551
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
552
+
553
+ for name, module in self.named_children():
554
+ fn_recursive_attn_processor(name, module, processor)
555
+
556
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
557
+ def set_default_attn_processor(self):
558
+ """
559
+ Disables custom attention processors and sets the default attention implementation.
560
+ """
561
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
562
+ processor = AttnAddedKVProcessor()
563
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
564
+ processor = AttnProcessor()
565
+ else:
566
+ raise ValueError(
567
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
568
+ )
569
+
570
+ self.set_attn_processor(processor, _remove_lora=True)
571
+
572
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
573
+ def set_attention_slice(self, slice_size):
574
+ r"""
575
+ Enable sliced attention computation.
576
+
577
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
578
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
579
+
580
+ Args:
581
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
582
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
583
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
584
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
585
+ must be a multiple of `slice_size`.
586
+ """
587
+ sliceable_head_dims = []
588
+
589
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
590
+ if hasattr(module, "set_attention_slice"):
591
+ sliceable_head_dims.append(module.sliceable_head_dim)
592
+
593
+ for child in module.children():
594
+ fn_recursive_retrieve_sliceable_dims(child)
595
+
596
+ # retrieve number of attention layers
597
+ for module in self.children():
598
+ fn_recursive_retrieve_sliceable_dims(module)
599
+
600
+ num_sliceable_layers = len(sliceable_head_dims)
601
+
602
+ if slice_size == "auto":
603
+ # half the attention head size is usually a good trade-off between
604
+ # speed and memory
605
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
606
+ elif slice_size == "max":
607
+ # make smallest slice possible
608
+ slice_size = num_sliceable_layers * [1]
609
+
610
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
611
+
612
+ if len(slice_size) != len(sliceable_head_dims):
613
+ raise ValueError(
614
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
615
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
616
+ )
617
+
618
+ for i in range(len(slice_size)):
619
+ size = slice_size[i]
620
+ dim = sliceable_head_dims[i]
621
+ if size is not None and size > dim:
622
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
623
+
624
+ # Recursively walk through all the children.
625
+ # Any children which exposes the set_attention_slice method
626
+ # gets the message
627
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
628
+ if hasattr(module, "set_attention_slice"):
629
+ module.set_attention_slice(slice_size.pop())
630
+
631
+ for child in module.children():
632
+ fn_recursive_set_attention_slice(child, slice_size)
633
+
634
+ reversed_slice_size = list(reversed(slice_size))
635
+ for module in self.children():
636
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
637
+
638
+ def _set_gradient_checkpointing(self, module, value=False):
639
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
640
+ module.gradient_checkpointing = value
641
+
642
+ def forward(
643
+ self,
644
+ sample: torch.FloatTensor,
645
+ timestep: Union[torch.Tensor, float, int],
646
+ encoder_hidden_states: torch.Tensor,
647
+ controlnet_cond: torch.FloatTensor,
648
+ conditioning_scale: float = 1.0,
649
+ class_labels: Optional[torch.Tensor] = None,
650
+ timestep_cond: Optional[torch.Tensor] = None,
651
+ attention_mask: Optional[torch.Tensor] = None,
652
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
653
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
654
+ guess_mode: bool = False,
655
+ return_dict: bool = True,
656
+ ) -> Union[ControlNetOutput, Tuple]:
657
+ """
658
+ The [`ControlNetModel`] forward method.
659
+
660
+ Args:
661
+ sample (`torch.FloatTensor`):
662
+ The noisy input tensor.
663
+ timestep (`Union[torch.Tensor, float, int]`):
664
+ The number of timesteps to denoise an input.
665
+ encoder_hidden_states (`torch.Tensor`):
666
+ The encoder hidden states.
667
+ controlnet_cond (`torch.FloatTensor`):
668
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
669
+ conditioning_scale (`float`, defaults to `1.0`):
670
+ The scale factor for ControlNet outputs.
671
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
672
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
673
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
674
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
675
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
676
+ embeddings.
677
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
678
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
679
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
680
+ negative values to the attention scores corresponding to "discard" tokens.
681
+ added_cond_kwargs (`dict`):
682
+ Additional conditions for the Stable Diffusion XL UNet.
683
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
684
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
685
+ guess_mode (`bool`, defaults to `False`):
686
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
687
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
688
+ return_dict (`bool`, defaults to `True`):
689
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
690
+
691
+ Returns:
692
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
693
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
694
+ returned where the first element is the sample tensor.
695
+ """
696
+ # check channel order
697
+ channel_order = self.config.controlnet_conditioning_channel_order
698
+
699
+ if channel_order == "rgb":
700
+ # in rgb order by default
701
+ ...
702
+ elif channel_order == "bgr":
703
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
704
+ else:
705
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
706
+
707
+ # prepare attention_mask
708
+ if attention_mask is not None:
709
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
710
+ attention_mask = attention_mask.unsqueeze(1)
711
+
712
+ # 1. time
713
+ timesteps = timestep
714
+ if not torch.is_tensor(timesteps):
715
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
716
+ # This would be a good case for the `match` statement (Python 3.10+)
717
+ is_mps = sample.device.type == "mps"
718
+ if isinstance(timestep, float):
719
+ dtype = torch.float32 if is_mps else torch.float64
720
+ else:
721
+ dtype = torch.int32 if is_mps else torch.int64
722
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
723
+ elif len(timesteps.shape) == 0:
724
+ timesteps = timesteps[None].to(sample.device)
725
+
726
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
727
+ timesteps = timesteps.expand(sample.shape[0])
728
+
729
+ t_emb = self.time_proj(timesteps)
730
+
731
+ # timesteps does not contain any weights and will always return f32 tensors
732
+ # but time_embedding might actually be running in fp16. so we need to cast here.
733
+ # there might be better ways to encapsulate this.
734
+ t_emb = t_emb.to(dtype=sample.dtype)
735
+
736
+ emb = self.time_embedding(t_emb, timestep_cond)
737
+ aug_emb = None
738
+
739
+ if self.class_embedding is not None:
740
+ if class_labels is None:
741
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
742
+
743
+ if self.config.class_embed_type == "timestep":
744
+ class_labels = self.time_proj(class_labels)
745
+
746
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
747
+ emb = emb + class_emb
748
+
749
+ if self.config.addition_embed_type is not None:
750
+ if self.config.addition_embed_type == "text":
751
+ aug_emb = self.add_embedding(encoder_hidden_states)
752
+
753
+ elif self.config.addition_embed_type == "text_time":
754
+ if "text_embeds" not in added_cond_kwargs:
755
+ raise ValueError(
756
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
757
+ )
758
+ text_embeds = added_cond_kwargs.get("text_embeds")
759
+ if "time_ids" not in added_cond_kwargs:
760
+ raise ValueError(
761
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
762
+ )
763
+ time_ids = added_cond_kwargs.get("time_ids")
764
+ time_embeds = self.add_time_proj(time_ids.flatten())
765
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
766
+
767
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
768
+ add_embeds = add_embeds.to(emb.dtype)
769
+ aug_emb = self.add_embedding(add_embeds)
770
+
771
+ emb = emb + aug_emb if aug_emb is not None else emb
772
+
773
+ # 2. pre-process
774
+ sample = self.conv_in(sample)
775
+
776
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
777
+ sample = sample + controlnet_cond
778
+
779
+ # 3. down
780
+ down_block_res_samples = (sample,)
781
+ for downsample_block in self.down_blocks:
782
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
783
+ sample, res_samples = downsample_block(
784
+ hidden_states=sample,
785
+ temb=emb,
786
+ encoder_hidden_states=encoder_hidden_states,
787
+ attention_mask=attention_mask,
788
+ cross_attention_kwargs=cross_attention_kwargs,
789
+ )
790
+ else:
791
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
792
+
793
+ down_block_res_samples += res_samples
794
+
795
+ # 4. mid
796
+ if self.mid_block is not None:
797
+ sample = self.mid_block(
798
+ sample,
799
+ emb,
800
+ encoder_hidden_states=encoder_hidden_states,
801
+ attention_mask=attention_mask,
802
+ cross_attention_kwargs=cross_attention_kwargs,
803
+ )
804
+
805
+ # 5. Control net blocks
806
+
807
+ controlnet_down_block_res_samples = ()
808
+
809
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
810
+ down_block_res_sample = controlnet_block(down_block_res_sample)
811
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
812
+
813
+ down_block_res_samples = controlnet_down_block_res_samples
814
+
815
+ mid_block_res_sample = self.controlnet_mid_block(sample)
816
+
817
+ # 6. scaling
818
+ if guess_mode and not self.config.global_pool_conditions:
819
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
820
+ scales = scales * conditioning_scale
821
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
822
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
823
+ else:
824
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
825
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
826
+
827
+ if self.config.global_pool_conditions:
828
+ down_block_res_samples = [
829
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
830
+ ]
831
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
832
+
833
+ if not return_dict:
834
+ return (down_block_res_samples, mid_block_res_sample)
835
+
836
+ return ControlNetOutput(
837
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
838
+ )
839
+
840
+
841
+ def zero_module(module):
842
+ for p in module.parameters():
843
+ nn.init.zeros_(p)
844
+ return module
diffusers/models/controlnet_flax.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import flax
17
+ import flax.linen as nn
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from flax.core.frozen_dict import FrozenDict
21
+
22
+ from ..configuration_utils import ConfigMixin, flax_register_to_config
23
+ from ..utils import BaseOutput
24
+ from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from .modeling_flax_utils import FlaxModelMixin
26
+ from .unet_2d_blocks_flax import (
27
+ FlaxCrossAttnDownBlock2D,
28
+ FlaxDownBlock2D,
29
+ FlaxUNetMidBlock2DCrossAttn,
30
+ )
31
+
32
+
33
+ @flax.struct.dataclass
34
+ class FlaxControlNetOutput(BaseOutput):
35
+ """
36
+ The output of [`FlaxControlNetModel`].
37
+
38
+ Args:
39
+ down_block_res_samples (`jnp.ndarray`):
40
+ mid_block_res_sample (`jnp.ndarray`):
41
+ """
42
+
43
+ down_block_res_samples: jnp.ndarray
44
+ mid_block_res_sample: jnp.ndarray
45
+
46
+
47
+ class FlaxControlNetConditioningEmbedding(nn.Module):
48
+ conditioning_embedding_channels: int
49
+ block_out_channels: Tuple[int] = (16, 32, 96, 256)
50
+ dtype: jnp.dtype = jnp.float32
51
+
52
+ def setup(self):
53
+ self.conv_in = nn.Conv(
54
+ self.block_out_channels[0],
55
+ kernel_size=(3, 3),
56
+ padding=((1, 1), (1, 1)),
57
+ dtype=self.dtype,
58
+ )
59
+
60
+ blocks = []
61
+ for i in range(len(self.block_out_channels) - 1):
62
+ channel_in = self.block_out_channels[i]
63
+ channel_out = self.block_out_channels[i + 1]
64
+ conv1 = nn.Conv(
65
+ channel_in,
66
+ kernel_size=(3, 3),
67
+ padding=((1, 1), (1, 1)),
68
+ dtype=self.dtype,
69
+ )
70
+ blocks.append(conv1)
71
+ conv2 = nn.Conv(
72
+ channel_out,
73
+ kernel_size=(3, 3),
74
+ strides=(2, 2),
75
+ padding=((1, 1), (1, 1)),
76
+ dtype=self.dtype,
77
+ )
78
+ blocks.append(conv2)
79
+ self.blocks = blocks
80
+
81
+ self.conv_out = nn.Conv(
82
+ self.conditioning_embedding_channels,
83
+ kernel_size=(3, 3),
84
+ padding=((1, 1), (1, 1)),
85
+ kernel_init=nn.initializers.zeros_init(),
86
+ bias_init=nn.initializers.zeros_init(),
87
+ dtype=self.dtype,
88
+ )
89
+
90
+ def __call__(self, conditioning):
91
+ embedding = self.conv_in(conditioning)
92
+ embedding = nn.silu(embedding)
93
+
94
+ for block in self.blocks:
95
+ embedding = block(embedding)
96
+ embedding = nn.silu(embedding)
97
+
98
+ embedding = self.conv_out(embedding)
99
+
100
+ return embedding
101
+
102
+
103
+ @flax_register_to_config
104
+ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
105
+ r"""
106
+ A ControlNet model.
107
+
108
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods
109
+ implemented for all models (such as downloading or saving).
110
+
111
+ This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
112
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
113
+ general usage and behavior.
114
+
115
+ Inherent JAX features such as the following are supported:
116
+
117
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
118
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
119
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
120
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
121
+
122
+ Parameters:
123
+ sample_size (`int`, *optional*):
124
+ The size of the input sample.
125
+ in_channels (`int`, *optional*, defaults to 4):
126
+ The number of channels in the input sample.
127
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
128
+ The tuple of downsample blocks to use.
129
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
130
+ The tuple of output channels for each block.
131
+ layers_per_block (`int`, *optional*, defaults to 2):
132
+ The number of layers per block.
133
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
134
+ The dimension of the attention heads.
135
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
136
+ The number of attention heads.
137
+ cross_attention_dim (`int`, *optional*, defaults to 768):
138
+ The dimension of the cross attention features.
139
+ dropout (`float`, *optional*, defaults to 0):
140
+ Dropout probability for down, up and bottleneck blocks.
141
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
142
+ Whether to flip the sin to cos in the time embedding.
143
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
144
+ controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
145
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
146
+ conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
147
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
148
+ """
149
+ sample_size: int = 32
150
+ in_channels: int = 4
151
+ down_block_types: Tuple[str] = (
152
+ "CrossAttnDownBlock2D",
153
+ "CrossAttnDownBlock2D",
154
+ "CrossAttnDownBlock2D",
155
+ "DownBlock2D",
156
+ )
157
+ only_cross_attention: Union[bool, Tuple[bool]] = False
158
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
159
+ layers_per_block: int = 2
160
+ attention_head_dim: Union[int, Tuple[int]] = 8
161
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
162
+ cross_attention_dim: int = 1280
163
+ dropout: float = 0.0
164
+ use_linear_projection: bool = False
165
+ dtype: jnp.dtype = jnp.float32
166
+ flip_sin_to_cos: bool = True
167
+ freq_shift: int = 0
168
+ controlnet_conditioning_channel_order: str = "rgb"
169
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
170
+
171
+ def init_weights(self, rng: jax.Array) -> FrozenDict:
172
+ # init input tensors
173
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
174
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
175
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
176
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
177
+ controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
178
+ controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
179
+
180
+ params_rng, dropout_rng = jax.random.split(rng)
181
+ rngs = {"params": params_rng, "dropout": dropout_rng}
182
+
183
+ return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
184
+
185
+ def setup(self):
186
+ block_out_channels = self.block_out_channels
187
+ time_embed_dim = block_out_channels[0] * 4
188
+
189
+ # If `num_attention_heads` is not defined (which is the case for most models)
190
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
191
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
192
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
193
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
194
+ # which is why we correct for the naming here.
195
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
196
+
197
+ # input
198
+ self.conv_in = nn.Conv(
199
+ block_out_channels[0],
200
+ kernel_size=(3, 3),
201
+ strides=(1, 1),
202
+ padding=((1, 1), (1, 1)),
203
+ dtype=self.dtype,
204
+ )
205
+
206
+ # time
207
+ self.time_proj = FlaxTimesteps(
208
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
209
+ )
210
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
211
+
212
+ self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
213
+ conditioning_embedding_channels=block_out_channels[0],
214
+ block_out_channels=self.conditioning_embedding_out_channels,
215
+ )
216
+
217
+ only_cross_attention = self.only_cross_attention
218
+ if isinstance(only_cross_attention, bool):
219
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
220
+
221
+ if isinstance(num_attention_heads, int):
222
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
223
+
224
+ # down
225
+ down_blocks = []
226
+ controlnet_down_blocks = []
227
+
228
+ output_channel = block_out_channels[0]
229
+
230
+ controlnet_block = nn.Conv(
231
+ output_channel,
232
+ kernel_size=(1, 1),
233
+ padding="VALID",
234
+ kernel_init=nn.initializers.zeros_init(),
235
+ bias_init=nn.initializers.zeros_init(),
236
+ dtype=self.dtype,
237
+ )
238
+ controlnet_down_blocks.append(controlnet_block)
239
+
240
+ for i, down_block_type in enumerate(self.down_block_types):
241
+ input_channel = output_channel
242
+ output_channel = block_out_channels[i]
243
+ is_final_block = i == len(block_out_channels) - 1
244
+
245
+ if down_block_type == "CrossAttnDownBlock2D":
246
+ down_block = FlaxCrossAttnDownBlock2D(
247
+ in_channels=input_channel,
248
+ out_channels=output_channel,
249
+ dropout=self.dropout,
250
+ num_layers=self.layers_per_block,
251
+ num_attention_heads=num_attention_heads[i],
252
+ add_downsample=not is_final_block,
253
+ use_linear_projection=self.use_linear_projection,
254
+ only_cross_attention=only_cross_attention[i],
255
+ dtype=self.dtype,
256
+ )
257
+ else:
258
+ down_block = FlaxDownBlock2D(
259
+ in_channels=input_channel,
260
+ out_channels=output_channel,
261
+ dropout=self.dropout,
262
+ num_layers=self.layers_per_block,
263
+ add_downsample=not is_final_block,
264
+ dtype=self.dtype,
265
+ )
266
+
267
+ down_blocks.append(down_block)
268
+
269
+ for _ in range(self.layers_per_block):
270
+ controlnet_block = nn.Conv(
271
+ output_channel,
272
+ kernel_size=(1, 1),
273
+ padding="VALID",
274
+ kernel_init=nn.initializers.zeros_init(),
275
+ bias_init=nn.initializers.zeros_init(),
276
+ dtype=self.dtype,
277
+ )
278
+ controlnet_down_blocks.append(controlnet_block)
279
+
280
+ if not is_final_block:
281
+ controlnet_block = nn.Conv(
282
+ output_channel,
283
+ kernel_size=(1, 1),
284
+ padding="VALID",
285
+ kernel_init=nn.initializers.zeros_init(),
286
+ bias_init=nn.initializers.zeros_init(),
287
+ dtype=self.dtype,
288
+ )
289
+ controlnet_down_blocks.append(controlnet_block)
290
+
291
+ self.down_blocks = down_blocks
292
+ self.controlnet_down_blocks = controlnet_down_blocks
293
+
294
+ # mid
295
+ mid_block_channel = block_out_channels[-1]
296
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
297
+ in_channels=mid_block_channel,
298
+ dropout=self.dropout,
299
+ num_attention_heads=num_attention_heads[-1],
300
+ use_linear_projection=self.use_linear_projection,
301
+ dtype=self.dtype,
302
+ )
303
+
304
+ self.controlnet_mid_block = nn.Conv(
305
+ mid_block_channel,
306
+ kernel_size=(1, 1),
307
+ padding="VALID",
308
+ kernel_init=nn.initializers.zeros_init(),
309
+ bias_init=nn.initializers.zeros_init(),
310
+ dtype=self.dtype,
311
+ )
312
+
313
+ def __call__(
314
+ self,
315
+ sample,
316
+ timesteps,
317
+ encoder_hidden_states,
318
+ controlnet_cond,
319
+ conditioning_scale: float = 1.0,
320
+ return_dict: bool = True,
321
+ train: bool = False,
322
+ ) -> Union[FlaxControlNetOutput, Tuple]:
323
+ r"""
324
+ Args:
325
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
326
+ timestep (`jnp.ndarray` or `float` or `int`): timesteps
327
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
328
+ controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
329
+ conditioning_scale: (`float`) the scale factor for controlnet outputs
330
+ return_dict (`bool`, *optional*, defaults to `True`):
331
+ Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
332
+ plain tuple.
333
+ train (`bool`, *optional*, defaults to `False`):
334
+ Use deterministic functions and disable dropout when not training.
335
+
336
+ Returns:
337
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
338
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
339
+ When returning a tuple, the first element is the sample tensor.
340
+ """
341
+ channel_order = self.controlnet_conditioning_channel_order
342
+ if channel_order == "bgr":
343
+ controlnet_cond = jnp.flip(controlnet_cond, axis=1)
344
+
345
+ # 1. time
346
+ if not isinstance(timesteps, jnp.ndarray):
347
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
348
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
349
+ timesteps = timesteps.astype(dtype=jnp.float32)
350
+ timesteps = jnp.expand_dims(timesteps, 0)
351
+
352
+ t_emb = self.time_proj(timesteps)
353
+ t_emb = self.time_embedding(t_emb)
354
+
355
+ # 2. pre-process
356
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
357
+ sample = self.conv_in(sample)
358
+
359
+ controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
360
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
361
+ sample += controlnet_cond
362
+
363
+ # 3. down
364
+ down_block_res_samples = (sample,)
365
+ for down_block in self.down_blocks:
366
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
367
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
368
+ else:
369
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
370
+ down_block_res_samples += res_samples
371
+
372
+ # 4. mid
373
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
374
+
375
+ # 5. contronet blocks
376
+ controlnet_down_block_res_samples = ()
377
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
378
+ down_block_res_sample = controlnet_block(down_block_res_sample)
379
+ controlnet_down_block_res_samples += (down_block_res_sample,)
380
+
381
+ down_block_res_samples = controlnet_down_block_res_samples
382
+
383
+ mid_block_res_sample = self.controlnet_mid_block(sample)
384
+
385
+ # 6. scaling
386
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
387
+ mid_block_res_sample *= conditioning_scale
388
+
389
+ if not return_dict:
390
+ return (down_block_res_samples, mid_block_res_sample)
391
+
392
+ return FlaxControlNetOutput(
393
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
394
+ )
diffusers/models/dual_transformer_2d.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Module):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to but not more than steps than `num_embeds_ada_norm`.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.ModuleList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ attention_mask=None,
103
+ cross_attention_kwargs=None,
104
+ return_dict: bool = True,
105
+ ):
106
+ """
107
+ Args:
108
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
+ hidden_states.
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.FloatTensor`, *optional*):
117
+ Optional attention mask to be applied in Attention.
118
+ cross_attention_kwargs (`dict`, *optional*):
119
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
120
+ `self.processor` in
121
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
122
+ return_dict (`bool`, *optional*, defaults to `True`):
123
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
124
+
125
+ Returns:
126
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
127
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
128
+ returning a tuple, the first element is the sample tensor.
129
+ """
130
+ input_states = hidden_states
131
+
132
+ encoded_states = []
133
+ tokens_start = 0
134
+ # attention_mask is not used yet
135
+ for i in range(2):
136
+ # for each of the two transformers, pass the corresponding condition tokens
137
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
138
+ transformer_index = self.transformer_index_for_condition[i]
139
+ encoded_state = self.transformers[transformer_index](
140
+ input_states,
141
+ encoder_hidden_states=condition_state,
142
+ timestep=timestep,
143
+ cross_attention_kwargs=cross_attention_kwargs,
144
+ return_dict=False,
145
+ )[0]
146
+ encoded_states.append(encoded_state - input_states)
147
+ tokens_start += self.condition_lengths[i]
148
+
149
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
150
+ output_states = output_states + input_states
151
+
152
+ if not return_dict:
153
+ return (output_states,)
154
+
155
+ return Transformer2DModelOutput(sample=output_states)
diffusers/models/embeddings.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+
21
+ from ..utils import USE_PEFT_BACKEND
22
+ from .activations import get_activation
23
+ from .lora import LoRACompatibleLinear
24
+
25
+
26
+ def get_timestep_embedding(
27
+ timesteps: torch.Tensor,
28
+ embedding_dim: int,
29
+ flip_sin_to_cos: bool = False,
30
+ downscale_freq_shift: float = 1,
31
+ scale: float = 1,
32
+ max_period: int = 10000,
33
+ ):
34
+ """
35
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
36
+
37
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
38
+ These may be fractional.
39
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
40
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
41
+ """
42
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
43
+
44
+ half_dim = embedding_dim // 2
45
+ exponent = -math.log(max_period) * torch.arange(
46
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
47
+ )
48
+ exponent = exponent / (half_dim - downscale_freq_shift)
49
+
50
+ emb = torch.exp(exponent)
51
+ emb = timesteps[:, None].float() * emb[None, :]
52
+
53
+ # scale embeddings
54
+ emb = scale * emb
55
+
56
+ # concat sine and cosine embeddings
57
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
58
+
59
+ # flip sine and cosine embeddings
60
+ if flip_sin_to_cos:
61
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
62
+
63
+ # zero pad
64
+ if embedding_dim % 2 == 1:
65
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
66
+ return emb
67
+
68
+
69
+ def get_2d_sincos_pos_embed(
70
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
71
+ ):
72
+ """
73
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
74
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
75
+ """
76
+ if isinstance(grid_size, int):
77
+ grid_size = (grid_size, grid_size)
78
+
79
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
80
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
81
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
82
+ grid = np.stack(grid, axis=0)
83
+
84
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
85
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
86
+ if cls_token and extra_tokens > 0:
87
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
88
+ return pos_embed
89
+
90
+
91
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
92
+ if embed_dim % 2 != 0:
93
+ raise ValueError("embed_dim must be divisible by 2")
94
+
95
+ # use half of dimensions to encode grid_h
96
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
97
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
98
+
99
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
100
+ return emb
101
+
102
+
103
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
104
+ """
105
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
106
+ """
107
+ if embed_dim % 2 != 0:
108
+ raise ValueError("embed_dim must be divisible by 2")
109
+
110
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
111
+ omega /= embed_dim / 2.0
112
+ omega = 1.0 / 10000**omega # (D/2,)
113
+
114
+ pos = pos.reshape(-1) # (M,)
115
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
116
+
117
+ emb_sin = np.sin(out) # (M, D/2)
118
+ emb_cos = np.cos(out) # (M, D/2)
119
+
120
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
121
+ return emb
122
+
123
+
124
+ class PatchEmbed(nn.Module):
125
+ """2D Image to Patch Embedding"""
126
+
127
+ def __init__(
128
+ self,
129
+ height=224,
130
+ width=224,
131
+ patch_size=16,
132
+ in_channels=3,
133
+ embed_dim=768,
134
+ layer_norm=False,
135
+ flatten=True,
136
+ bias=True,
137
+ interpolation_scale=1,
138
+ ):
139
+ super().__init__()
140
+
141
+ num_patches = (height // patch_size) * (width // patch_size)
142
+ self.flatten = flatten
143
+ self.layer_norm = layer_norm
144
+
145
+ self.proj = nn.Conv2d(
146
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
147
+ )
148
+ if layer_norm:
149
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
150
+ else:
151
+ self.norm = None
152
+
153
+ self.patch_size = patch_size
154
+ # See:
155
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
156
+ self.height, self.width = height // patch_size, width // patch_size
157
+ self.base_size = height // patch_size
158
+ self.interpolation_scale = interpolation_scale
159
+ pos_embed = get_2d_sincos_pos_embed(
160
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
161
+ )
162
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
163
+
164
+ def forward(self, latent):
165
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
166
+
167
+ latent = self.proj(latent)
168
+ if self.flatten:
169
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
170
+ if self.layer_norm:
171
+ latent = self.norm(latent)
172
+
173
+ # Interpolate positional embeddings if needed.
174
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
175
+ if self.height != height or self.width != width:
176
+ pos_embed = get_2d_sincos_pos_embed(
177
+ embed_dim=self.pos_embed.shape[-1],
178
+ grid_size=(height, width),
179
+ base_size=self.base_size,
180
+ interpolation_scale=self.interpolation_scale,
181
+ )
182
+ pos_embed = torch.from_numpy(pos_embed)
183
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
184
+ else:
185
+ pos_embed = self.pos_embed
186
+
187
+ return (latent + pos_embed).to(latent.dtype)
188
+
189
+
190
+ class TimestepEmbedding(nn.Module):
191
+ def __init__(
192
+ self,
193
+ in_channels: int,
194
+ time_embed_dim: int,
195
+ act_fn: str = "silu",
196
+ out_dim: int = None,
197
+ post_act_fn: Optional[str] = None,
198
+ cond_proj_dim=None,
199
+ ):
200
+ super().__init__()
201
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
202
+
203
+ self.linear_1 = linear_cls(in_channels, time_embed_dim)
204
+
205
+ if cond_proj_dim is not None:
206
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
207
+ else:
208
+ self.cond_proj = None
209
+
210
+ self.act = get_activation(act_fn)
211
+
212
+ if out_dim is not None:
213
+ time_embed_dim_out = out_dim
214
+ else:
215
+ time_embed_dim_out = time_embed_dim
216
+ self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
217
+
218
+ if post_act_fn is None:
219
+ self.post_act = None
220
+ else:
221
+ self.post_act = get_activation(post_act_fn)
222
+
223
+ def forward(self, sample, condition=None):
224
+ if condition is not None:
225
+ sample = sample + self.cond_proj(condition)
226
+ sample = self.linear_1(sample)
227
+
228
+ if self.act is not None:
229
+ sample = self.act(sample)
230
+
231
+ sample = self.linear_2(sample)
232
+
233
+ if self.post_act is not None:
234
+ sample = self.post_act(sample)
235
+ return sample
236
+
237
+
238
+ class Timesteps(nn.Module):
239
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
240
+ super().__init__()
241
+ self.num_channels = num_channels
242
+ self.flip_sin_to_cos = flip_sin_to_cos
243
+ self.downscale_freq_shift = downscale_freq_shift
244
+
245
+ def forward(self, timesteps):
246
+ t_emb = get_timestep_embedding(
247
+ timesteps,
248
+ self.num_channels,
249
+ flip_sin_to_cos=self.flip_sin_to_cos,
250
+ downscale_freq_shift=self.downscale_freq_shift,
251
+ )
252
+ return t_emb
253
+
254
+
255
+ class GaussianFourierProjection(nn.Module):
256
+ """Gaussian Fourier embeddings for noise levels."""
257
+
258
+ def __init__(
259
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
260
+ ):
261
+ super().__init__()
262
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
263
+ self.log = log
264
+ self.flip_sin_to_cos = flip_sin_to_cos
265
+
266
+ if set_W_to_weight:
267
+ # to delete later
268
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
269
+
270
+ self.weight = self.W
271
+
272
+ def forward(self, x):
273
+ if self.log:
274
+ x = torch.log(x)
275
+
276
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
277
+
278
+ if self.flip_sin_to_cos:
279
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
280
+ else:
281
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
282
+ return out
283
+
284
+
285
+ class SinusoidalPositionalEmbedding(nn.Module):
286
+ """Apply positional information to a sequence of embeddings.
287
+
288
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
289
+ them
290
+
291
+ Args:
292
+ embed_dim: (int): Dimension of the positional embedding.
293
+ max_seq_length: Maximum sequence length to apply positional embeddings
294
+
295
+ """
296
+
297
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
298
+ super().__init__()
299
+ position = torch.arange(max_seq_length).unsqueeze(1)
300
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
301
+ pe = torch.zeros(1, max_seq_length, embed_dim)
302
+ pe[0, :, 0::2] = torch.sin(position * div_term)
303
+ pe[0, :, 1::2] = torch.cos(position * div_term)
304
+ self.register_buffer("pe", pe)
305
+
306
+ def forward(self, x):
307
+ _, seq_length, _ = x.shape
308
+ x = x + self.pe[:, :seq_length]
309
+ return x
310
+
311
+
312
+ class ImagePositionalEmbeddings(nn.Module):
313
+ """
314
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
315
+ height and width of the latent space.
316
+
317
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
318
+
319
+ For VQ-diffusion:
320
+
321
+ Output vector embeddings are used as input for the transformer.
322
+
323
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
324
+
325
+ Args:
326
+ num_embed (`int`):
327
+ Number of embeddings for the latent pixels embeddings.
328
+ height (`int`):
329
+ Height of the latent image i.e. the number of height embeddings.
330
+ width (`int`):
331
+ Width of the latent image i.e. the number of width embeddings.
332
+ embed_dim (`int`):
333
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ num_embed: int,
339
+ height: int,
340
+ width: int,
341
+ embed_dim: int,
342
+ ):
343
+ super().__init__()
344
+
345
+ self.height = height
346
+ self.width = width
347
+ self.num_embed = num_embed
348
+ self.embed_dim = embed_dim
349
+
350
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
351
+ self.height_emb = nn.Embedding(self.height, embed_dim)
352
+ self.width_emb = nn.Embedding(self.width, embed_dim)
353
+
354
+ def forward(self, index):
355
+ emb = self.emb(index)
356
+
357
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
358
+
359
+ # 1 x H x D -> 1 x H x 1 x D
360
+ height_emb = height_emb.unsqueeze(2)
361
+
362
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
363
+
364
+ # 1 x W x D -> 1 x 1 x W x D
365
+ width_emb = width_emb.unsqueeze(1)
366
+
367
+ pos_emb = height_emb + width_emb
368
+
369
+ # 1 x H x W x D -> 1 x L xD
370
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
371
+
372
+ emb = emb + pos_emb[:, : emb.shape[1], :]
373
+
374
+ return emb
375
+
376
+
377
+ class LabelEmbedding(nn.Module):
378
+ """
379
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
380
+
381
+ Args:
382
+ num_classes (`int`): The number of classes.
383
+ hidden_size (`int`): The size of the vector embeddings.
384
+ dropout_prob (`float`): The probability of dropping a label.
385
+ """
386
+
387
+ def __init__(self, num_classes, hidden_size, dropout_prob):
388
+ super().__init__()
389
+ use_cfg_embedding = dropout_prob > 0
390
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
391
+ self.num_classes = num_classes
392
+ self.dropout_prob = dropout_prob
393
+
394
+ def token_drop(self, labels, force_drop_ids=None):
395
+ """
396
+ Drops labels to enable classifier-free guidance.
397
+ """
398
+ if force_drop_ids is None:
399
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
400
+ else:
401
+ drop_ids = torch.tensor(force_drop_ids == 1)
402
+ labels = torch.where(drop_ids, self.num_classes, labels)
403
+ return labels
404
+
405
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
406
+ use_dropout = self.dropout_prob > 0
407
+ if (self.training and use_dropout) or (force_drop_ids is not None):
408
+ labels = self.token_drop(labels, force_drop_ids)
409
+ embeddings = self.embedding_table(labels)
410
+ return embeddings
411
+
412
+
413
+ class TextImageProjection(nn.Module):
414
+ def __init__(
415
+ self,
416
+ text_embed_dim: int = 1024,
417
+ image_embed_dim: int = 768,
418
+ cross_attention_dim: int = 768,
419
+ num_image_text_embeds: int = 10,
420
+ ):
421
+ super().__init__()
422
+
423
+ self.num_image_text_embeds = num_image_text_embeds
424
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
425
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
426
+
427
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
428
+ batch_size = text_embeds.shape[0]
429
+
430
+ # image
431
+ image_text_embeds = self.image_embeds(image_embeds)
432
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
433
+
434
+ # text
435
+ text_embeds = self.text_proj(text_embeds)
436
+
437
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
438
+
439
+
440
+ class ImageProjection(nn.Module):
441
+ def __init__(
442
+ self,
443
+ image_embed_dim: int = 768,
444
+ cross_attention_dim: int = 768,
445
+ num_image_text_embeds: int = 32,
446
+ ):
447
+ super().__init__()
448
+
449
+ self.num_image_text_embeds = num_image_text_embeds
450
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
451
+ self.norm = nn.LayerNorm(cross_attention_dim)
452
+
453
+ def forward(self, image_embeds: torch.FloatTensor):
454
+ batch_size = image_embeds.shape[0]
455
+
456
+ # image
457
+ image_embeds = self.image_embeds(image_embeds)
458
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
459
+ image_embeds = self.norm(image_embeds)
460
+ return image_embeds
461
+
462
+
463
+ class CombinedTimestepLabelEmbeddings(nn.Module):
464
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
465
+ super().__init__()
466
+
467
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
468
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
469
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
470
+
471
+ def forward(self, timestep, class_labels, hidden_dtype=None):
472
+ timesteps_proj = self.time_proj(timestep)
473
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
474
+
475
+ class_labels = self.class_embedder(class_labels) # (N, D)
476
+
477
+ conditioning = timesteps_emb + class_labels # (N, D)
478
+
479
+ return conditioning
480
+
481
+
482
+ class TextTimeEmbedding(nn.Module):
483
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
484
+ super().__init__()
485
+ self.norm1 = nn.LayerNorm(encoder_dim)
486
+ self.pool = AttentionPooling(num_heads, encoder_dim)
487
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
488
+ self.norm2 = nn.LayerNorm(time_embed_dim)
489
+
490
+ def forward(self, hidden_states):
491
+ hidden_states = self.norm1(hidden_states)
492
+ hidden_states = self.pool(hidden_states)
493
+ hidden_states = self.proj(hidden_states)
494
+ hidden_states = self.norm2(hidden_states)
495
+ return hidden_states
496
+
497
+
498
+ class TextImageTimeEmbedding(nn.Module):
499
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
500
+ super().__init__()
501
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
502
+ self.text_norm = nn.LayerNorm(time_embed_dim)
503
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
504
+
505
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
506
+ # text
507
+ time_text_embeds = self.text_proj(text_embeds)
508
+ time_text_embeds = self.text_norm(time_text_embeds)
509
+
510
+ # image
511
+ time_image_embeds = self.image_proj(image_embeds)
512
+
513
+ return time_image_embeds + time_text_embeds
514
+
515
+
516
+ class ImageTimeEmbedding(nn.Module):
517
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
518
+ super().__init__()
519
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
520
+ self.image_norm = nn.LayerNorm(time_embed_dim)
521
+
522
+ def forward(self, image_embeds: torch.FloatTensor):
523
+ # image
524
+ time_image_embeds = self.image_proj(image_embeds)
525
+ time_image_embeds = self.image_norm(time_image_embeds)
526
+ return time_image_embeds
527
+
528
+
529
+ class ImageHintTimeEmbedding(nn.Module):
530
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
531
+ super().__init__()
532
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
533
+ self.image_norm = nn.LayerNorm(time_embed_dim)
534
+ self.input_hint_block = nn.Sequential(
535
+ nn.Conv2d(3, 16, 3, padding=1),
536
+ nn.SiLU(),
537
+ nn.Conv2d(16, 16, 3, padding=1),
538
+ nn.SiLU(),
539
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
540
+ nn.SiLU(),
541
+ nn.Conv2d(32, 32, 3, padding=1),
542
+ nn.SiLU(),
543
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
544
+ nn.SiLU(),
545
+ nn.Conv2d(96, 96, 3, padding=1),
546
+ nn.SiLU(),
547
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
548
+ nn.SiLU(),
549
+ nn.Conv2d(256, 4, 3, padding=1),
550
+ )
551
+
552
+ def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
553
+ # image
554
+ time_image_embeds = self.image_proj(image_embeds)
555
+ time_image_embeds = self.image_norm(time_image_embeds)
556
+ hint = self.input_hint_block(hint)
557
+ return time_image_embeds, hint
558
+
559
+
560
+ class AttentionPooling(nn.Module):
561
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
562
+
563
+ def __init__(self, num_heads, embed_dim, dtype=None):
564
+ super().__init__()
565
+ self.dtype = dtype
566
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
567
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
568
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
569
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
570
+ self.num_heads = num_heads
571
+ self.dim_per_head = embed_dim // self.num_heads
572
+
573
+ def forward(self, x):
574
+ bs, length, width = x.size()
575
+
576
+ def shape(x):
577
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
578
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
579
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
580
+ x = x.transpose(1, 2)
581
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
582
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
583
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
584
+ x = x.transpose(1, 2)
585
+ return x
586
+
587
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
588
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
589
+
590
+ # (bs*n_heads, class_token_length, dim_per_head)
591
+ q = shape(self.q_proj(class_token))
592
+ # (bs*n_heads, length+class_token_length, dim_per_head)
593
+ k = shape(self.k_proj(x))
594
+ v = shape(self.v_proj(x))
595
+
596
+ # (bs*n_heads, class_token_length, length+class_token_length):
597
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
598
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
599
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
600
+
601
+ # (bs*n_heads, dim_per_head, class_token_length)
602
+ a = torch.einsum("bts,bcs->bct", weight, v)
603
+
604
+ # (bs, length+1, width)
605
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
606
+
607
+ return a[:, 0, :] # cls_token
608
+
609
+
610
+ class FourierEmbedder(nn.Module):
611
+ def __init__(self, num_freqs=64, temperature=100):
612
+ super().__init__()
613
+
614
+ self.num_freqs = num_freqs
615
+ self.temperature = temperature
616
+
617
+ freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
618
+ freq_bands = freq_bands[None, None, None]
619
+ self.register_buffer("freq_bands", freq_bands, persistent=False)
620
+
621
+ def __call__(self, x):
622
+ x = self.freq_bands * x.unsqueeze(-1)
623
+ return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
624
+
625
+
626
+ class PositionNet(nn.Module):
627
+ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
628
+ super().__init__()
629
+ self.positive_len = positive_len
630
+ self.out_dim = out_dim
631
+
632
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
633
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
634
+
635
+ if isinstance(out_dim, tuple):
636
+ out_dim = out_dim[0]
637
+
638
+ if feature_type == "text-only":
639
+ self.linears = nn.Sequential(
640
+ nn.Linear(self.positive_len + self.position_dim, 512),
641
+ nn.SiLU(),
642
+ nn.Linear(512, 512),
643
+ nn.SiLU(),
644
+ nn.Linear(512, out_dim),
645
+ )
646
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
647
+
648
+ elif feature_type == "text-image":
649
+ self.linears_text = nn.Sequential(
650
+ nn.Linear(self.positive_len + self.position_dim, 512),
651
+ nn.SiLU(),
652
+ nn.Linear(512, 512),
653
+ nn.SiLU(),
654
+ nn.Linear(512, out_dim),
655
+ )
656
+ self.linears_image = nn.Sequential(
657
+ nn.Linear(self.positive_len + self.position_dim, 512),
658
+ nn.SiLU(),
659
+ nn.Linear(512, 512),
660
+ nn.SiLU(),
661
+ nn.Linear(512, out_dim),
662
+ )
663
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
664
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
665
+
666
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
667
+
668
+ def forward(
669
+ self,
670
+ boxes,
671
+ masks,
672
+ positive_embeddings=None,
673
+ phrases_masks=None,
674
+ image_masks=None,
675
+ phrases_embeddings=None,
676
+ image_embeddings=None,
677
+ ):
678
+ masks = masks.unsqueeze(-1)
679
+
680
+ # embedding position (it may includes padding as placeholder)
681
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
682
+
683
+ # learnable null embedding
684
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
685
+
686
+ # replace padding with learnable null embedding
687
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
688
+
689
+ # positionet with text only information
690
+ if positive_embeddings is not None:
691
+ # learnable null embedding
692
+ positive_null = self.null_positive_feature.view(1, 1, -1)
693
+
694
+ # replace padding with learnable null embedding
695
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
696
+
697
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
698
+
699
+ # positionet with text and image infomation
700
+ else:
701
+ phrases_masks = phrases_masks.unsqueeze(-1)
702
+ image_masks = image_masks.unsqueeze(-1)
703
+
704
+ # learnable null embedding
705
+ text_null = self.null_text_feature.view(1, 1, -1)
706
+ image_null = self.null_image_feature.view(1, 1, -1)
707
+
708
+ # replace padding with learnable null embedding
709
+ phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
710
+ image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
711
+
712
+ objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
713
+ objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
714
+ objs = torch.cat([objs_text, objs_image], dim=1)
715
+
716
+ return objs
717
+
718
+
719
+ class CombinedTimestepSizeEmbeddings(nn.Module):
720
+ """
721
+ For PixArt-Alpha.
722
+
723
+ Reference:
724
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
725
+ """
726
+
727
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
728
+ super().__init__()
729
+
730
+ self.outdim = size_emb_dim
731
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
732
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
733
+
734
+ self.use_additional_conditions = use_additional_conditions
735
+ if use_additional_conditions:
736
+ self.use_additional_conditions = True
737
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
738
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
739
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
740
+
741
+ def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
742
+ if size.ndim == 1:
743
+ size = size[:, None]
744
+
745
+ if size.shape[0] != batch_size:
746
+ size = size.repeat(batch_size // size.shape[0], 1)
747
+ if size.shape[0] != batch_size:
748
+ raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
749
+
750
+ current_batch_size, dims = size.shape[0], size.shape[1]
751
+ size = size.reshape(-1)
752
+ size_freq = self.additional_condition_proj(size).to(size.dtype)
753
+
754
+ size_emb = embedder(size_freq)
755
+ size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
756
+ return size_emb
757
+
758
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
759
+ timesteps_proj = self.time_proj(timestep)
760
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
761
+
762
+ if self.use_additional_conditions:
763
+ resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
764
+ aspect_ratio = self.apply_condition(
765
+ aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
766
+ )
767
+ conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
768
+ else:
769
+ conditioning = timesteps_emb
770
+
771
+ return conditioning
772
+
773
+
774
+ class CaptionProjection(nn.Module):
775
+ """
776
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
777
+
778
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
779
+ """
780
+
781
+ def __init__(self, in_features, hidden_size, num_tokens=120):
782
+ super().__init__()
783
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
784
+ self.act_1 = nn.GELU(approximate="tanh")
785
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
786
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
787
+
788
+ def forward(self, caption, force_drop_ids=None):
789
+ hidden_states = self.linear_1(caption)
790
+ hidden_states = self.act_1(hidden_states)
791
+ hidden_states = self.linear_2(hidden_states)
792
+ return hidden_states
diffusers/models/embeddings_flax.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import flax.linen as nn
17
+ import jax.numpy as jnp
18
+
19
+
20
+ def get_sinusoidal_embeddings(
21
+ timesteps: jnp.ndarray,
22
+ embedding_dim: int,
23
+ freq_shift: float = 1,
24
+ min_timescale: float = 1,
25
+ max_timescale: float = 1.0e4,
26
+ flip_sin_to_cos: bool = False,
27
+ scale: float = 1.0,
28
+ ) -> jnp.ndarray:
29
+ """Returns the positional encoding (same as Tensor2Tensor).
30
+
31
+ Args:
32
+ timesteps: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ embedding_dim: The number of output channels.
35
+ min_timescale: The smallest time unit (should probably be 0.0).
36
+ max_timescale: The largest time unit.
37
+ Returns:
38
+ a Tensor of timing signals [N, num_channels]
39
+ """
40
+ assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
41
+ assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
42
+ num_timescales = float(embedding_dim // 2)
43
+ log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
44
+ inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
45
+ emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
46
+
47
+ # scale embeddings
48
+ scaled_time = scale * emb
49
+
50
+ if flip_sin_to_cos:
51
+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
52
+ else:
53
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
54
+ signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
55
+ return signal
56
+
57
+
58
+ class FlaxTimestepEmbedding(nn.Module):
59
+ r"""
60
+ Time step Embedding Module. Learns embeddings for input time steps.
61
+
62
+ Args:
63
+ time_embed_dim (`int`, *optional*, defaults to `32`):
64
+ Time step embedding dimension
65
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66
+ Parameters `dtype`
67
+ """
68
+ time_embed_dim: int = 32
69
+ dtype: jnp.dtype = jnp.float32
70
+
71
+ @nn.compact
72
+ def __call__(self, temb):
73
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
74
+ temb = nn.silu(temb)
75
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
76
+ return temb
77
+
78
+
79
+ class FlaxTimesteps(nn.Module):
80
+ r"""
81
+ Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
82
+
83
+ Args:
84
+ dim (`int`, *optional*, defaults to `32`):
85
+ Time step embedding dimension
86
+ """
87
+ dim: int = 32
88
+ flip_sin_to_cos: bool = False
89
+ freq_shift: float = 1
90
+
91
+ @nn.compact
92
+ def __call__(self, timesteps):
93
+ return get_sinusoidal_embeddings(
94
+ timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
95
+ )
diffusers/models/lora.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
22
+ from ..utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+
27
+
28
+ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
29
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
30
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
31
+ attn_module.q_proj.lora_scale = lora_scale
32
+ attn_module.k_proj.lora_scale = lora_scale
33
+ attn_module.v_proj.lora_scale = lora_scale
34
+ attn_module.out_proj.lora_scale = lora_scale
35
+
36
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
37
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
38
+ mlp_module.fc1.lora_scale = lora_scale
39
+ mlp_module.fc2.lora_scale = lora_scale
40
+
41
+
42
+ class LoRALinearLayer(nn.Module):
43
+ r"""
44
+ A linear layer that is used with LoRA.
45
+
46
+ Parameters:
47
+ in_features (`int`):
48
+ Number of input features.
49
+ out_features (`int`):
50
+ Number of output features.
51
+ rank (`int`, `optional`, defaults to 4):
52
+ The rank of the LoRA layer.
53
+ network_alpha (`float`, `optional`, defaults to `None`):
54
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
55
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
56
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
57
+ device (`torch.device`, `optional`, defaults to `None`):
58
+ The device to use for the layer's weights.
59
+ dtype (`torch.dtype`, `optional`, defaults to `None`):
60
+ The dtype to use for the layer's weights.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ in_features: int,
66
+ out_features: int,
67
+ rank: int = 4,
68
+ network_alpha: Optional[float] = None,
69
+ device: Optional[Union[torch.device, str]] = None,
70
+ dtype: Optional[torch.dtype] = None,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
75
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
76
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
77
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
78
+ self.network_alpha = network_alpha
79
+ self.rank = rank
80
+ self.out_features = out_features
81
+ self.in_features = in_features
82
+
83
+ nn.init.normal_(self.down.weight, std=1 / rank)
84
+ nn.init.zeros_(self.up.weight)
85
+
86
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
87
+ orig_dtype = hidden_states.dtype
88
+ dtype = self.down.weight.dtype
89
+
90
+ down_hidden_states = self.down(hidden_states.to(dtype))
91
+ up_hidden_states = self.up(down_hidden_states)
92
+
93
+ if self.network_alpha is not None:
94
+ up_hidden_states *= self.network_alpha / self.rank
95
+
96
+ return up_hidden_states.to(orig_dtype)
97
+
98
+
99
+ class LoRAConv2dLayer(nn.Module):
100
+ r"""
101
+ A convolutional layer that is used with LoRA.
102
+
103
+ Parameters:
104
+ in_features (`int`):
105
+ Number of input features.
106
+ out_features (`int`):
107
+ Number of output features.
108
+ rank (`int`, `optional`, defaults to 4):
109
+ The rank of the LoRA layer.
110
+ kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
111
+ The kernel size of the convolution.
112
+ stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
113
+ The stride of the convolution.
114
+ padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
115
+ The padding of the convolution.
116
+ network_alpha (`float`, `optional`, defaults to `None`):
117
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
118
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
119
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ in_features: int,
125
+ out_features: int,
126
+ rank: int = 4,
127
+ kernel_size: Union[int, Tuple[int, int]] = (1, 1),
128
+ stride: Union[int, Tuple[int, int]] = (1, 1),
129
+ padding: Union[int, Tuple[int, int], str] = 0,
130
+ network_alpha: Optional[float] = None,
131
+ ):
132
+ super().__init__()
133
+
134
+ self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
135
+ # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
136
+ # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
137
+ self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
138
+
139
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
140
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
141
+ self.network_alpha = network_alpha
142
+ self.rank = rank
143
+
144
+ nn.init.normal_(self.down.weight, std=1 / rank)
145
+ nn.init.zeros_(self.up.weight)
146
+
147
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
148
+ orig_dtype = hidden_states.dtype
149
+ dtype = self.down.weight.dtype
150
+
151
+ down_hidden_states = self.down(hidden_states.to(dtype))
152
+ up_hidden_states = self.up(down_hidden_states)
153
+
154
+ if self.network_alpha is not None:
155
+ up_hidden_states *= self.network_alpha / self.rank
156
+
157
+ return up_hidden_states.to(orig_dtype)
158
+
159
+
160
+ class LoRACompatibleConv(nn.Conv2d):
161
+ """
162
+ A convolutional layer that can be used with LoRA.
163
+ """
164
+
165
+ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
166
+ super().__init__(*args, **kwargs)
167
+ self.lora_layer = lora_layer
168
+
169
+ def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
170
+ self.lora_layer = lora_layer
171
+
172
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
173
+ if self.lora_layer is None:
174
+ return
175
+
176
+ dtype, device = self.weight.data.dtype, self.weight.data.device
177
+
178
+ w_orig = self.weight.data.float()
179
+ w_up = self.lora_layer.up.weight.data.float()
180
+ w_down = self.lora_layer.down.weight.data.float()
181
+
182
+ if self.lora_layer.network_alpha is not None:
183
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
184
+
185
+ fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
186
+ fusion = fusion.reshape((w_orig.shape))
187
+ fused_weight = w_orig + (lora_scale * fusion)
188
+
189
+ if safe_fusing and torch.isnan(fused_weight).any().item():
190
+ raise ValueError(
191
+ "This LoRA weight seems to be broken. "
192
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
193
+ "LoRA weights will not be fused."
194
+ )
195
+
196
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
197
+
198
+ # we can drop the lora layer now
199
+ self.lora_layer = None
200
+
201
+ # offload the up and down matrices to CPU to not blow the memory
202
+ self.w_up = w_up.cpu()
203
+ self.w_down = w_down.cpu()
204
+ self._lora_scale = lora_scale
205
+
206
+ def _unfuse_lora(self):
207
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
208
+ return
209
+
210
+ fused_weight = self.weight.data
211
+ dtype, device = fused_weight.data.dtype, fused_weight.data.device
212
+
213
+ self.w_up = self.w_up.to(device=device).float()
214
+ self.w_down = self.w_down.to(device).float()
215
+
216
+ fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
217
+ fusion = fusion.reshape((fused_weight.shape))
218
+ unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
219
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
220
+
221
+ self.w_up = None
222
+ self.w_down = None
223
+
224
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
225
+ if self.lora_layer is None:
226
+ # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
227
+ # see: https://github.com/huggingface/diffusers/pull/4315
228
+ return F.conv2d(
229
+ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
230
+ )
231
+ else:
232
+ original_outputs = F.conv2d(
233
+ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
234
+ )
235
+ return original_outputs + (scale * self.lora_layer(hidden_states))
236
+
237
+
238
+ class LoRACompatibleLinear(nn.Linear):
239
+ """
240
+ A Linear layer that can be used with LoRA.
241
+ """
242
+
243
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
244
+ super().__init__(*args, **kwargs)
245
+ self.lora_layer = lora_layer
246
+
247
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
248
+ self.lora_layer = lora_layer
249
+
250
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
251
+ if self.lora_layer is None:
252
+ return
253
+
254
+ dtype, device = self.weight.data.dtype, self.weight.data.device
255
+
256
+ w_orig = self.weight.data.float()
257
+ w_up = self.lora_layer.up.weight.data.float()
258
+ w_down = self.lora_layer.down.weight.data.float()
259
+
260
+ if self.lora_layer.network_alpha is not None:
261
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
262
+
263
+ fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
264
+
265
+ if safe_fusing and torch.isnan(fused_weight).any().item():
266
+ raise ValueError(
267
+ "This LoRA weight seems to be broken. "
268
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
269
+ "LoRA weights will not be fused."
270
+ )
271
+
272
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
273
+
274
+ # we can drop the lora layer now
275
+ self.lora_layer = None
276
+
277
+ # offload the up and down matrices to CPU to not blow the memory
278
+ self.w_up = w_up.cpu()
279
+ self.w_down = w_down.cpu()
280
+ self._lora_scale = lora_scale
281
+
282
+ def _unfuse_lora(self):
283
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
284
+ return
285
+
286
+ fused_weight = self.weight.data
287
+ dtype, device = fused_weight.dtype, fused_weight.device
288
+
289
+ w_up = self.w_up.to(device=device).float()
290
+ w_down = self.w_down.to(device).float()
291
+
292
+ unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
293
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
294
+
295
+ self.w_up = None
296
+ self.w_down = None
297
+
298
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
299
+ if self.lora_layer is None:
300
+ out = super().forward(hidden_states)
301
+ return out
302
+ else:
303
+ out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
304
+ return out
diffusers/models/modeling_flax_pytorch_utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Flax general utilities."""
16
+ import re
17
+
18
+ import jax.numpy as jnp
19
+ from flax.traverse_util import flatten_dict, unflatten_dict
20
+ from jax.random import PRNGKey
21
+
22
+ from ..utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ def rename_key(key):
29
+ regex = r"\w+[.]\d+"
30
+ pats = re.findall(regex, key)
31
+ for pat in pats:
32
+ key = key.replace(pat, "_".join(pat.split(".")))
33
+ return key
34
+
35
+
36
+ #####################
37
+ # PyTorch => Flax #
38
+ #####################
39
+
40
+
41
+ # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
42
+ # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
43
+ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
44
+ """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
45
+ # conv norm or layer norm
46
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
47
+
48
+ # rename attention layers
49
+ if len(pt_tuple_key) > 1:
50
+ for rename_from, rename_to in (
51
+ ("to_out_0", "proj_attn"),
52
+ ("to_k", "key"),
53
+ ("to_v", "value"),
54
+ ("to_q", "query"),
55
+ ):
56
+ if pt_tuple_key[-2] == rename_from:
57
+ weight_name = pt_tuple_key[-1]
58
+ weight_name = "kernel" if weight_name == "weight" else weight_name
59
+ renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
60
+ if renamed_pt_tuple_key in random_flax_state_dict:
61
+ assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
62
+ return renamed_pt_tuple_key, pt_tensor.T
63
+
64
+ if (
65
+ any("norm" in str_ for str_ in pt_tuple_key)
66
+ and (pt_tuple_key[-1] == "bias")
67
+ and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
68
+ and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
69
+ ):
70
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
71
+ return renamed_pt_tuple_key, pt_tensor
72
+ elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
73
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
74
+ return renamed_pt_tuple_key, pt_tensor
75
+
76
+ # embedding
77
+ if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
78
+ pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
79
+ return renamed_pt_tuple_key, pt_tensor
80
+
81
+ # conv layer
82
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
83
+ if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
84
+ pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
85
+ return renamed_pt_tuple_key, pt_tensor
86
+
87
+ # linear layer
88
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
89
+ if pt_tuple_key[-1] == "weight":
90
+ pt_tensor = pt_tensor.T
91
+ return renamed_pt_tuple_key, pt_tensor
92
+
93
+ # old PyTorch layer norm weight
94
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
95
+ if pt_tuple_key[-1] == "gamma":
96
+ return renamed_pt_tuple_key, pt_tensor
97
+
98
+ # old PyTorch layer norm bias
99
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
100
+ if pt_tuple_key[-1] == "beta":
101
+ return renamed_pt_tuple_key, pt_tensor
102
+
103
+ return pt_tuple_key, pt_tensor
104
+
105
+
106
+ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
107
+ # Step 1: Convert pytorch tensor to numpy
108
+ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
109
+
110
+ # Step 2: Since the model is stateless, get random Flax params
111
+ random_flax_params = flax_model.init_weights(PRNGKey(init_key))
112
+
113
+ random_flax_state_dict = flatten_dict(random_flax_params)
114
+ flax_state_dict = {}
115
+
116
+ # Need to change some parameters name to match Flax names
117
+ for pt_key, pt_tensor in pt_state_dict.items():
118
+ renamed_pt_key = rename_key(pt_key)
119
+ pt_tuple_key = tuple(renamed_pt_key.split("."))
120
+
121
+ # Correctly rename weight parameters
122
+ flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
123
+
124
+ if flax_key in random_flax_state_dict:
125
+ if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
126
+ raise ValueError(
127
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
128
+ f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
129
+ )
130
+
131
+ # also add unexpected weight so that warning is thrown
132
+ flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
133
+
134
+ return unflatten_dict(flax_state_dict)
diffusers/models/modeling_flax_utils.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from pickle import UnpicklingError
18
+ from typing import Any, Dict, Union
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import msgpack.exceptions
23
+ from flax.core.frozen_dict import FrozenDict, unfreeze
24
+ from flax.serialization import from_bytes, to_bytes
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+ from huggingface_hub import create_repo, hf_hub_download
27
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
28
+ from requests import HTTPError
29
+
30
+ from .. import __version__, is_torch_available
31
+ from ..utils import (
32
+ CONFIG_NAME,
33
+ DIFFUSERS_CACHE,
34
+ FLAX_WEIGHTS_NAME,
35
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
36
+ WEIGHTS_NAME,
37
+ PushToHubMixin,
38
+ logging,
39
+ )
40
+ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ class FlaxModelMixin(PushToHubMixin):
47
+ r"""
48
+ Base class for all Flax models.
49
+
50
+ [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
51
+ saving models.
52
+
53
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
54
+ """
55
+ config_name = CONFIG_NAME
56
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
57
+ _flax_internal_args = ["name", "parent", "dtype"]
58
+
59
+ @classmethod
60
+ def _from_config(cls, config, **kwargs):
61
+ """
62
+ All context managers that the model should be initialized under go here.
63
+ """
64
+ return cls(config, **kwargs)
65
+
66
+ def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
67
+ """
68
+ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
69
+ """
70
+
71
+ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
72
+ def conditional_cast(param):
73
+ if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
74
+ param = param.astype(dtype)
75
+ return param
76
+
77
+ if mask is None:
78
+ return jax.tree_map(conditional_cast, params)
79
+
80
+ flat_params = flatten_dict(params)
81
+ flat_mask, _ = jax.tree_flatten(mask)
82
+
83
+ for masked, key in zip(flat_mask, flat_params.keys()):
84
+ if masked:
85
+ param = flat_params[key]
86
+ flat_params[key] = conditional_cast(param)
87
+
88
+ return unflatten_dict(flat_params)
89
+
90
+ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
91
+ r"""
92
+ Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
93
+ the `params` in place.
94
+
95
+ This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
96
+ half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
97
+
98
+ Arguments:
99
+ params (`Union[Dict, FrozenDict]`):
100
+ A `PyTree` of model parameters.
101
+ mask (`Union[Dict, FrozenDict]`):
102
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
103
+ for params you want to cast, and `False` for those you want to skip.
104
+
105
+ Examples:
106
+
107
+ ```python
108
+ >>> from diffusers import FlaxUNet2DConditionModel
109
+
110
+ >>> # load model
111
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
112
+ >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
113
+ >>> params = model.to_bf16(params)
114
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
115
+ >>> # then pass the mask as follows
116
+ >>> from flax import traverse_util
117
+
118
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
119
+ >>> flat_params = traverse_util.flatten_dict(params)
120
+ >>> mask = {
121
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
122
+ ... for path in flat_params
123
+ ... }
124
+ >>> mask = traverse_util.unflatten_dict(mask)
125
+ >>> params = model.to_bf16(params, mask)
126
+ ```"""
127
+ return self._cast_floating_to(params, jnp.bfloat16, mask)
128
+
129
+ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
130
+ r"""
131
+ Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
132
+ model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
133
+
134
+ Arguments:
135
+ params (`Union[Dict, FrozenDict]`):
136
+ A `PyTree` of model parameters.
137
+ mask (`Union[Dict, FrozenDict]`):
138
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
139
+ for params you want to cast, and `False` for those you want to skip.
140
+
141
+ Examples:
142
+
143
+ ```python
144
+ >>> from diffusers import FlaxUNet2DConditionModel
145
+
146
+ >>> # Download model and configuration from huggingface.co
147
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
148
+ >>> # By default, the model params will be in fp32, to illustrate the use of this method,
149
+ >>> # we'll first cast to fp16 and back to fp32
150
+ >>> params = model.to_f16(params)
151
+ >>> # now cast back to fp32
152
+ >>> params = model.to_fp32(params)
153
+ ```"""
154
+ return self._cast_floating_to(params, jnp.float32, mask)
155
+
156
+ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
157
+ r"""
158
+ Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
159
+ `params` in place.
160
+
161
+ This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
162
+ half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
163
+
164
+ Arguments:
165
+ params (`Union[Dict, FrozenDict]`):
166
+ A `PyTree` of model parameters.
167
+ mask (`Union[Dict, FrozenDict]`):
168
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
169
+ for params you want to cast, and `False` for those you want to skip.
170
+
171
+ Examples:
172
+
173
+ ```python
174
+ >>> from diffusers import FlaxUNet2DConditionModel
175
+
176
+ >>> # load model
177
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
178
+ >>> # By default, the model params will be in fp32, to cast these to float16
179
+ >>> params = model.to_fp16(params)
180
+ >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
181
+ >>> # then pass the mask as follows
182
+ >>> from flax import traverse_util
183
+
184
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
185
+ >>> flat_params = traverse_util.flatten_dict(params)
186
+ >>> mask = {
187
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
188
+ ... for path in flat_params
189
+ ... }
190
+ >>> mask = traverse_util.unflatten_dict(mask)
191
+ >>> params = model.to_fp16(params, mask)
192
+ ```"""
193
+ return self._cast_floating_to(params, jnp.float16, mask)
194
+
195
+ def init_weights(self, rng: jax.Array) -> Dict:
196
+ raise NotImplementedError(f"init_weights method has to be implemented for {self}")
197
+
198
+ @classmethod
199
+ def from_pretrained(
200
+ cls,
201
+ pretrained_model_name_or_path: Union[str, os.PathLike],
202
+ dtype: jnp.dtype = jnp.float32,
203
+ *model_args,
204
+ **kwargs,
205
+ ):
206
+ r"""
207
+ Instantiate a pretrained Flax model from a pretrained model configuration.
208
+
209
+ Parameters:
210
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
211
+ Can be either:
212
+
213
+ - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
214
+ hosted on the Hub.
215
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
216
+ using [`~FlaxModelMixin.save_pretrained`].
217
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
218
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
219
+ `jax.numpy.bfloat16` (on TPUs).
220
+
221
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
222
+ specified, all the computation will be performed with the given `dtype`.
223
+
224
+ <Tip>
225
+
226
+ This only specifies the dtype of the *computation* and does not influence the dtype of model
227
+ parameters.
228
+
229
+ If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
230
+ [`~FlaxModelMixin.to_bf16`].
231
+
232
+ </Tip>
233
+
234
+ model_args (sequence of positional arguments, *optional*):
235
+ All remaining positional arguments are passed to the underlying model's `__init__` method.
236
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
237
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
238
+ is not used.
239
+ force_download (`bool`, *optional*, defaults to `False`):
240
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
241
+ cached versions if they exist.
242
+ resume_download (`bool`, *optional*, defaults to `False`):
243
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
244
+ incompletely downloaded files are deleted.
245
+ proxies (`Dict[str, str]`, *optional*):
246
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
247
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
248
+ local_files_only(`bool`, *optional*, defaults to `False`):
249
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
250
+ won't be downloaded from the Hub.
251
+ revision (`str`, *optional*, defaults to `"main"`):
252
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
253
+ allowed by Git.
254
+ from_pt (`bool`, *optional*, defaults to `False`):
255
+ Load the model weights from a PyTorch checkpoint save file.
256
+ kwargs (remaining dictionary of keyword arguments, *optional*):
257
+ Can be used to update the configuration object (after it is loaded) and initiate the model (for
258
+ example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
259
+ automatically loaded:
260
+
261
+ - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
262
+ model's `__init__` method (we assume all relevant updates to the configuration have already been
263
+ done).
264
+ - If a configuration is not provided, `kwargs` are first passed to the configuration class
265
+ initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
266
+ to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
267
+ Remaining keys that do not correspond to any configuration attribute are passed to the underlying
268
+ model's `__init__` function.
269
+
270
+ Examples:
271
+
272
+ ```python
273
+ >>> from diffusers import FlaxUNet2DConditionModel
274
+
275
+ >>> # Download model and configuration from huggingface.co and cache.
276
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
277
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
278
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
279
+ ```
280
+
281
+ If you get the error message below, you need to finetune the weights for your downstream task:
282
+
283
+ ```bash
284
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
285
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
286
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
287
+ ```
288
+ """
289
+ config = kwargs.pop("config", None)
290
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
291
+ force_download = kwargs.pop("force_download", False)
292
+ from_pt = kwargs.pop("from_pt", False)
293
+ resume_download = kwargs.pop("resume_download", False)
294
+ proxies = kwargs.pop("proxies", None)
295
+ local_files_only = kwargs.pop("local_files_only", False)
296
+ use_auth_token = kwargs.pop("use_auth_token", None)
297
+ revision = kwargs.pop("revision", None)
298
+ subfolder = kwargs.pop("subfolder", None)
299
+
300
+ user_agent = {
301
+ "diffusers": __version__,
302
+ "file_type": "model",
303
+ "framework": "flax",
304
+ }
305
+
306
+ # Load config if we don't provide one
307
+ if config is None:
308
+ config, unused_kwargs = cls.load_config(
309
+ pretrained_model_name_or_path,
310
+ cache_dir=cache_dir,
311
+ return_unused_kwargs=True,
312
+ force_download=force_download,
313
+ resume_download=resume_download,
314
+ proxies=proxies,
315
+ local_files_only=local_files_only,
316
+ use_auth_token=use_auth_token,
317
+ revision=revision,
318
+ subfolder=subfolder,
319
+ **kwargs,
320
+ )
321
+
322
+ model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
323
+
324
+ # Load model
325
+ pretrained_path_with_subfolder = (
326
+ pretrained_model_name_or_path
327
+ if subfolder is None
328
+ else os.path.join(pretrained_model_name_or_path, subfolder)
329
+ )
330
+ if os.path.isdir(pretrained_path_with_subfolder):
331
+ if from_pt:
332
+ if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
333
+ raise EnvironmentError(
334
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
335
+ )
336
+ model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
337
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
338
+ # Load from a Flax checkpoint
339
+ model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
340
+ # Check if pytorch weights exist instead
341
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
342
+ raise EnvironmentError(
343
+ f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
344
+ " using `from_pt=True`."
345
+ )
346
+ else:
347
+ raise EnvironmentError(
348
+ f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
349
+ f"{pretrained_path_with_subfolder}."
350
+ )
351
+ else:
352
+ try:
353
+ model_file = hf_hub_download(
354
+ pretrained_model_name_or_path,
355
+ filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
356
+ cache_dir=cache_dir,
357
+ force_download=force_download,
358
+ proxies=proxies,
359
+ resume_download=resume_download,
360
+ local_files_only=local_files_only,
361
+ use_auth_token=use_auth_token,
362
+ user_agent=user_agent,
363
+ subfolder=subfolder,
364
+ revision=revision,
365
+ )
366
+
367
+ except RepositoryNotFoundError:
368
+ raise EnvironmentError(
369
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
370
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
371
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
372
+ "login`."
373
+ )
374
+ except RevisionNotFoundError:
375
+ raise EnvironmentError(
376
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
377
+ "this model name. Check the model page at "
378
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
379
+ )
380
+ except EntryNotFoundError:
381
+ raise EnvironmentError(
382
+ f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
383
+ )
384
+ except HTTPError as err:
385
+ raise EnvironmentError(
386
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
387
+ f"{err}"
388
+ )
389
+ except ValueError:
390
+ raise EnvironmentError(
391
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
392
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
393
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
394
+ " internet connection or see how to run the library in offline mode at"
395
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
396
+ )
397
+ except EnvironmentError:
398
+ raise EnvironmentError(
399
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
400
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
401
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
402
+ f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
403
+ )
404
+
405
+ if from_pt:
406
+ if is_torch_available():
407
+ from .modeling_utils import load_state_dict
408
+ else:
409
+ raise EnvironmentError(
410
+ "Can't load the model in PyTorch format because PyTorch is not installed. "
411
+ "Please, install PyTorch or use native Flax weights."
412
+ )
413
+
414
+ # Step 1: Get the pytorch file
415
+ pytorch_model_file = load_state_dict(model_file)
416
+
417
+ # Step 2: Convert the weights
418
+ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
419
+ else:
420
+ try:
421
+ with open(model_file, "rb") as state_f:
422
+ state = from_bytes(cls, state_f.read())
423
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
424
+ try:
425
+ with open(model_file) as f:
426
+ if f.read().startswith("version"):
427
+ raise OSError(
428
+ "You seem to have cloned a repository without having git-lfs installed. Please"
429
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
430
+ " folder you cloned."
431
+ )
432
+ else:
433
+ raise ValueError from e
434
+ except (UnicodeDecodeError, ValueError):
435
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
436
+ # make sure all arrays are stored as jnp.ndarray
437
+ # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
438
+ # https://github.com/google/flax/issues/1261
439
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
440
+
441
+ # flatten dicts
442
+ state = flatten_dict(state)
443
+
444
+ params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
445
+ required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
446
+
447
+ shape_state = flatten_dict(unfreeze(params_shape_tree))
448
+
449
+ missing_keys = required_params - set(state.keys())
450
+ unexpected_keys = set(state.keys()) - required_params
451
+
452
+ if missing_keys:
453
+ logger.warning(
454
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
455
+ "Make sure to call model.init_weights to initialize the missing weights."
456
+ )
457
+ cls._missing_keys = missing_keys
458
+
459
+ for key in state.keys():
460
+ if key in shape_state and state[key].shape != shape_state[key].shape:
461
+ raise ValueError(
462
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
463
+ f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
464
+ )
465
+
466
+ # remove unexpected keys to not be saved again
467
+ for unexpected_key in unexpected_keys:
468
+ del state[unexpected_key]
469
+
470
+ if len(unexpected_keys) > 0:
471
+ logger.warning(
472
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
473
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
474
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
475
+ " with another architecture."
476
+ )
477
+ else:
478
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
479
+
480
+ if len(missing_keys) > 0:
481
+ logger.warning(
482
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
483
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
484
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
485
+ )
486
+ else:
487
+ logger.info(
488
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
489
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
490
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
491
+ " training."
492
+ )
493
+
494
+ return model, unflatten_dict(state)
495
+
496
+ def save_pretrained(
497
+ self,
498
+ save_directory: Union[str, os.PathLike],
499
+ params: Union[Dict, FrozenDict],
500
+ is_main_process: bool = True,
501
+ push_to_hub: bool = False,
502
+ **kwargs,
503
+ ):
504
+ """
505
+ Save a model and its configuration file to a directory so that it can be reloaded using the
506
+ [`~FlaxModelMixin.from_pretrained`] class method.
507
+
508
+ Arguments:
509
+ save_directory (`str` or `os.PathLike`):
510
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
511
+ params (`Union[Dict, FrozenDict]`):
512
+ A `PyTree` of model parameters.
513
+ is_main_process (`bool`, *optional*, defaults to `True`):
514
+ Whether the process calling this is the main process or not. Useful during distributed training and you
515
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
516
+ process to avoid race conditions.
517
+ push_to_hub (`bool`, *optional*, defaults to `False`):
518
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
519
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
520
+ namespace).
521
+ kwargs (`Dict[str, Any]`, *optional*):
522
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
523
+ """
524
+ if os.path.isfile(save_directory):
525
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
526
+ return
527
+
528
+ os.makedirs(save_directory, exist_ok=True)
529
+
530
+ if push_to_hub:
531
+ commit_message = kwargs.pop("commit_message", None)
532
+ private = kwargs.pop("private", False)
533
+ create_pr = kwargs.pop("create_pr", False)
534
+ token = kwargs.pop("token", None)
535
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
536
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
537
+
538
+ model_to_save = self
539
+
540
+ # Attach architecture to the config
541
+ # Save the config
542
+ if is_main_process:
543
+ model_to_save.save_config(save_directory)
544
+
545
+ # save model
546
+ output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
547
+ with open(output_model_file, "wb") as f:
548
+ model_bytes = to_bytes(params)
549
+ f.write(model_bytes)
550
+
551
+ logger.info(f"Model weights saved in {output_model_file}")
552
+
553
+ if push_to_hub:
554
+ self._upload_folder(
555
+ save_directory,
556
+ repo_id,
557
+ token=token,
558
+ commit_message=commit_message,
559
+ create_pr=create_pr,
560
+ )
diffusers/models/modeling_pytorch_flax_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Flax general utilities."""
16
+
17
+ from pickle import UnpicklingError
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ from flax.serialization import from_bytes
23
+ from flax.traverse_util import flatten_dict
24
+
25
+ from ..utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ #####################
32
+ # Flax => PyTorch #
33
+ #####################
34
+
35
+
36
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
37
+ def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
38
+ try:
39
+ with open(model_file, "rb") as flax_state_f:
40
+ flax_state = from_bytes(None, flax_state_f.read())
41
+ except UnpicklingError as e:
42
+ try:
43
+ with open(model_file) as f:
44
+ if f.read().startswith("version"):
45
+ raise OSError(
46
+ "You seem to have cloned a repository without having git-lfs installed. Please"
47
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
48
+ " folder you cloned."
49
+ )
50
+ else:
51
+ raise ValueError from e
52
+ except (UnicodeDecodeError, ValueError):
53
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
54
+
55
+ return load_flax_weights_in_pytorch_model(pt_model, flax_state)
56
+
57
+
58
+ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
59
+ """Load flax checkpoints in a PyTorch model"""
60
+
61
+ try:
62
+ import torch # noqa: F401
63
+ except ImportError:
64
+ logger.error(
65
+ "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
66
+ " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
67
+ " instructions."
68
+ )
69
+ raise
70
+
71
+ # check if we have bf16 weights
72
+ is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
73
+ if any(is_type_bf16):
74
+ # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
75
+
76
+ # and bf16 is not fully supported in PT yet.
77
+ logger.warning(
78
+ "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
79
+ "before loading those in PyTorch model."
80
+ )
81
+ flax_state = jax.tree_util.tree_map(
82
+ lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
83
+ )
84
+
85
+ pt_model.base_model_prefix = ""
86
+
87
+ flax_state_dict = flatten_dict(flax_state, sep=".")
88
+ pt_model_dict = pt_model.state_dict()
89
+
90
+ # keep track of unexpected & missing keys
91
+ unexpected_keys = []
92
+ missing_keys = set(pt_model_dict.keys())
93
+
94
+ for flax_key_tuple, flax_tensor in flax_state_dict.items():
95
+ flax_key_tuple_array = flax_key_tuple.split(".")
96
+
97
+ if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
98
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
99
+ flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
100
+ elif flax_key_tuple_array[-1] == "kernel":
101
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
102
+ flax_tensor = flax_tensor.T
103
+ elif flax_key_tuple_array[-1] == "scale":
104
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
105
+
106
+ if "time_embedding" not in flax_key_tuple_array:
107
+ for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
108
+ flax_key_tuple_array[i] = (
109
+ flax_key_tuple_string.replace("_0", ".0")
110
+ .replace("_1", ".1")
111
+ .replace("_2", ".2")
112
+ .replace("_3", ".3")
113
+ .replace("_4", ".4")
114
+ .replace("_5", ".5")
115
+ .replace("_6", ".6")
116
+ .replace("_7", ".7")
117
+ .replace("_8", ".8")
118
+ .replace("_9", ".9")
119
+ )
120
+
121
+ flax_key = ".".join(flax_key_tuple_array)
122
+
123
+ if flax_key in pt_model_dict:
124
+ if flax_tensor.shape != pt_model_dict[flax_key].shape:
125
+ raise ValueError(
126
+ f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
127
+ f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
128
+ )
129
+ else:
130
+ # add weight to pytorch dict
131
+ flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
132
+ pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
133
+ # remove from missing keys
134
+ missing_keys.remove(flax_key)
135
+ else:
136
+ # weight is not expected by PyTorch model
137
+ unexpected_keys.append(flax_key)
138
+
139
+ pt_model.load_state_dict(pt_model_dict)
140
+
141
+ # re-transform missing_keys to list
142
+ missing_keys = list(missing_keys)
143
+
144
+ if len(unexpected_keys) > 0:
145
+ logger.warning(
146
+ "Some weights of the Flax model were not used when initializing the PyTorch model"
147
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
148
+ f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
149
+ " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
150
+ f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
151
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
152
+ " FlaxBertForSequenceClassification model)."
153
+ )
154
+ if len(missing_keys) > 0:
155
+ logger.warning(
156
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
157
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
158
+ " use it for predictions and inference."
159
+ )
160
+
161
+ return pt_model
diffusers/models/modeling_utils.py ADDED
@@ -0,0 +1,1158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import os
20
+ import re
21
+ from functools import partial
22
+ from typing import Any, Callable, List, Optional, Tuple, Union
23
+
24
+ import safetensors
25
+ import torch
26
+ from huggingface_hub import create_repo
27
+ from torch import Tensor, device, nn
28
+
29
+ from .. import __version__
30
+ from ..utils import (
31
+ CONFIG_NAME,
32
+ DIFFUSERS_CACHE,
33
+ FLAX_WEIGHTS_NAME,
34
+ HF_HUB_OFFLINE,
35
+ MIN_PEFT_VERSION,
36
+ SAFETENSORS_WEIGHTS_NAME,
37
+ WEIGHTS_NAME,
38
+ _add_variant,
39
+ _get_model_file,
40
+ check_peft_version,
41
+ deprecate,
42
+ is_accelerate_available,
43
+ is_torch_version,
44
+ logging,
45
+ )
46
+ from ..utils.hub_utils import PushToHubMixin
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ if is_torch_version(">=", "1.9.0"):
53
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
54
+ else:
55
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
56
+
57
+
58
+ if is_accelerate_available():
59
+ import accelerate
60
+ from accelerate.utils import set_module_tensor_to_device
61
+ from accelerate.utils.versions import is_torch_version
62
+
63
+
64
+ def get_parameter_device(parameter: torch.nn.Module):
65
+ try:
66
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
67
+ return next(parameters_and_buffers).device
68
+ except StopIteration:
69
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
70
+
71
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
72
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
73
+ return tuples
74
+
75
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
76
+ first_tuple = next(gen)
77
+ return first_tuple[1].device
78
+
79
+
80
+ def get_parameter_dtype(parameter: torch.nn.Module):
81
+ try:
82
+ params = tuple(parameter.parameters())
83
+ if len(params) > 0:
84
+ return params[0].dtype
85
+
86
+ buffers = tuple(parameter.buffers())
87
+ if len(buffers) > 0:
88
+ return buffers[0].dtype
89
+
90
+ except StopIteration:
91
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
92
+
93
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
94
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
95
+ return tuples
96
+
97
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
98
+ first_tuple = next(gen)
99
+ return first_tuple[1].dtype
100
+
101
+
102
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
103
+ """
104
+ Reads a checkpoint file, returning properly formatted errors if they arise.
105
+ """
106
+ try:
107
+ if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
108
+ return torch.load(checkpoint_file, map_location="cpu")
109
+ else:
110
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
111
+ except Exception as e:
112
+ try:
113
+ with open(checkpoint_file) as f:
114
+ if f.read().startswith("version"):
115
+ raise OSError(
116
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
117
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
118
+ "you cloned."
119
+ )
120
+ else:
121
+ raise ValueError(
122
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
123
+ "model. Make sure you have saved the model properly."
124
+ ) from e
125
+ except (UnicodeDecodeError, ValueError):
126
+ raise OSError(
127
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
128
+ f"at '{checkpoint_file}'. "
129
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
130
+ )
131
+
132
+
133
+ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
134
+ device = device or torch.device("cpu")
135
+ dtype = dtype or torch.float32
136
+
137
+ accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
138
+
139
+ unexpected_keys = []
140
+ empty_state_dict = model.state_dict()
141
+ for param_name, param in state_dict.items():
142
+ if param_name not in empty_state_dict:
143
+ unexpected_keys.append(param_name)
144
+ continue
145
+
146
+ if empty_state_dict[param_name].shape != param.shape:
147
+ model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
148
+ raise ValueError(
149
+ f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
150
+ )
151
+
152
+ if accepts_dtype:
153
+ set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
154
+ else:
155
+ set_module_tensor_to_device(model, param_name, device, value=param)
156
+ return unexpected_keys
157
+
158
+
159
+ def _load_state_dict_into_model(model_to_load, state_dict):
160
+ # Convert old format to new format if needed from a PyTorch state_dict
161
+ # copy state_dict so _load_from_state_dict can modify it
162
+ state_dict = state_dict.copy()
163
+ error_msgs = []
164
+
165
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
166
+ # so we need to apply the function recursively.
167
+ def load(module: torch.nn.Module, prefix=""):
168
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
169
+ module._load_from_state_dict(*args)
170
+
171
+ for name, child in module._modules.items():
172
+ if child is not None:
173
+ load(child, prefix + name + ".")
174
+
175
+ load(model_to_load)
176
+
177
+ return error_msgs
178
+
179
+
180
+ class ModelMixin(torch.nn.Module, PushToHubMixin):
181
+ r"""
182
+ Base class for all models.
183
+
184
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
185
+ saving models.
186
+
187
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
188
+ """
189
+ config_name = CONFIG_NAME
190
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
191
+ _supports_gradient_checkpointing = False
192
+ _keys_to_ignore_on_load_unexpected = None
193
+ _hf_peft_config_loaded = False
194
+
195
+ def __init__(self):
196
+ super().__init__()
197
+
198
+ def __getattr__(self, name: str) -> Any:
199
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
200
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
201
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
202
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
203
+ """
204
+
205
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
206
+ is_attribute = name in self.__dict__
207
+
208
+ if is_in_config and not is_attribute:
209
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
210
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
211
+ return self._internal_dict[name]
212
+
213
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
214
+ return super().__getattr__(name)
215
+
216
+ @property
217
+ def is_gradient_checkpointing(self) -> bool:
218
+ """
219
+ Whether gradient checkpointing is activated for this model or not.
220
+ """
221
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
222
+
223
+ def enable_gradient_checkpointing(self):
224
+ """
225
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
226
+ *checkpoint activations* in other frameworks).
227
+ """
228
+ if not self._supports_gradient_checkpointing:
229
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
230
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
231
+
232
+ def disable_gradient_checkpointing(self):
233
+ """
234
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
235
+ *checkpoint activations* in other frameworks).
236
+ """
237
+ if self._supports_gradient_checkpointing:
238
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
239
+
240
+ def set_use_memory_efficient_attention_xformers(
241
+ self, valid: bool, attention_op: Optional[Callable] = None
242
+ ) -> None:
243
+ # Recursively walk through all the children.
244
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
245
+ # gets the message
246
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
247
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
248
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
249
+
250
+ for child in module.children():
251
+ fn_recursive_set_mem_eff(child)
252
+
253
+ for module in self.children():
254
+ if isinstance(module, torch.nn.Module):
255
+ fn_recursive_set_mem_eff(module)
256
+
257
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
258
+ r"""
259
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
260
+
261
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
262
+ inference. Speed up during training is not guaranteed.
263
+
264
+ <Tip warning={true}>
265
+
266
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
267
+ precedent.
268
+
269
+ </Tip>
270
+
271
+ Parameters:
272
+ attention_op (`Callable`, *optional*):
273
+ Override the default `None` operator for use as `op` argument to the
274
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
275
+ function of xFormers.
276
+
277
+ Examples:
278
+
279
+ ```py
280
+ >>> import torch
281
+ >>> from diffusers import UNet2DConditionModel
282
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
283
+
284
+ >>> model = UNet2DConditionModel.from_pretrained(
285
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
286
+ ... )
287
+ >>> model = model.to("cuda")
288
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
289
+ ```
290
+ """
291
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
292
+
293
+ def disable_xformers_memory_efficient_attention(self):
294
+ r"""
295
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
296
+ """
297
+ self.set_use_memory_efficient_attention_xformers(False)
298
+
299
+ def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
300
+ r"""
301
+ Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
302
+ to the adapter to follow the convention of the PEFT library.
303
+
304
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
305
+ [documentation](https://huggingface.co/docs/peft).
306
+
307
+ Args:
308
+ adapter_config (`[~peft.PeftConfig]`):
309
+ The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
310
+ methods.
311
+ adapter_name (`str`, *optional*, defaults to `"default"`):
312
+ The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
313
+ """
314
+ check_peft_version(min_version=MIN_PEFT_VERSION)
315
+
316
+ from peft import PeftConfig, inject_adapter_in_model
317
+
318
+ if not self._hf_peft_config_loaded:
319
+ self._hf_peft_config_loaded = True
320
+ elif adapter_name in self.peft_config:
321
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
322
+
323
+ if not isinstance(adapter_config, PeftConfig):
324
+ raise ValueError(
325
+ f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
326
+ )
327
+
328
+ # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
329
+ # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
330
+ adapter_config.base_model_name_or_path = None
331
+ inject_adapter_in_model(adapter_config, self, adapter_name)
332
+ self.set_adapter(adapter_name)
333
+
334
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
335
+ """
336
+ Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
337
+
338
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
339
+ official documentation: https://huggingface.co/docs/peft
340
+
341
+ Args:
342
+ adapter_name (Union[str, List[str]])):
343
+ The list of adapters to set or the adapter name in case of single adapter.
344
+ """
345
+ check_peft_version(min_version=MIN_PEFT_VERSION)
346
+
347
+ if not self._hf_peft_config_loaded:
348
+ raise ValueError("No adapter loaded. Please load an adapter first.")
349
+
350
+ if isinstance(adapter_name, str):
351
+ adapter_name = [adapter_name]
352
+
353
+ missing = set(adapter_name) - set(self.peft_config)
354
+ if len(missing) > 0:
355
+ raise ValueError(
356
+ f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
357
+ f" current loaded adapters are: {list(self.peft_config.keys())}"
358
+ )
359
+
360
+ from peft.tuners.tuners_utils import BaseTunerLayer
361
+
362
+ _adapters_has_been_set = False
363
+
364
+ for _, module in self.named_modules():
365
+ if isinstance(module, BaseTunerLayer):
366
+ if hasattr(module, "set_adapter"):
367
+ module.set_adapter(adapter_name)
368
+ # Previous versions of PEFT does not support multi-adapter inference
369
+ elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
370
+ raise ValueError(
371
+ "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
372
+ " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
373
+ )
374
+ else:
375
+ module.active_adapter = adapter_name
376
+ _adapters_has_been_set = True
377
+
378
+ if not _adapters_has_been_set:
379
+ raise ValueError(
380
+ "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
381
+ )
382
+
383
+ def disable_adapters(self) -> None:
384
+ r"""
385
+ Disable all adapters attached to the model and fallback to inference with the base model only.
386
+
387
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
388
+ official documentation: https://huggingface.co/docs/peft
389
+ """
390
+ check_peft_version(min_version=MIN_PEFT_VERSION)
391
+
392
+ if not self._hf_peft_config_loaded:
393
+ raise ValueError("No adapter loaded. Please load an adapter first.")
394
+
395
+ from peft.tuners.tuners_utils import BaseTunerLayer
396
+
397
+ for _, module in self.named_modules():
398
+ if isinstance(module, BaseTunerLayer):
399
+ if hasattr(module, "enable_adapters"):
400
+ module.enable_adapters(enabled=False)
401
+ else:
402
+ # support for older PEFT versions
403
+ module.disable_adapters = True
404
+
405
+ def enable_adapters(self) -> None:
406
+ """
407
+ Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
408
+ list of adapters to enable.
409
+
410
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
411
+ official documentation: https://huggingface.co/docs/peft
412
+ """
413
+ check_peft_version(min_version=MIN_PEFT_VERSION)
414
+
415
+ if not self._hf_peft_config_loaded:
416
+ raise ValueError("No adapter loaded. Please load an adapter first.")
417
+
418
+ from peft.tuners.tuners_utils import BaseTunerLayer
419
+
420
+ for _, module in self.named_modules():
421
+ if isinstance(module, BaseTunerLayer):
422
+ if hasattr(module, "enable_adapters"):
423
+ module.enable_adapters(enabled=True)
424
+ else:
425
+ # support for older PEFT versions
426
+ module.disable_adapters = False
427
+
428
+ def active_adapters(self) -> List[str]:
429
+ """
430
+ Gets the current list of active adapters of the model.
431
+
432
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
433
+ official documentation: https://huggingface.co/docs/peft
434
+ """
435
+ check_peft_version(min_version=MIN_PEFT_VERSION)
436
+
437
+ if not self._hf_peft_config_loaded:
438
+ raise ValueError("No adapter loaded. Please load an adapter first.")
439
+
440
+ from peft.tuners.tuners_utils import BaseTunerLayer
441
+
442
+ for _, module in self.named_modules():
443
+ if isinstance(module, BaseTunerLayer):
444
+ return module.active_adapter
445
+
446
+ def save_pretrained(
447
+ self,
448
+ save_directory: Union[str, os.PathLike],
449
+ is_main_process: bool = True,
450
+ save_function: Callable = None,
451
+ safe_serialization: bool = True,
452
+ variant: Optional[str] = None,
453
+ push_to_hub: bool = False,
454
+ **kwargs,
455
+ ):
456
+ """
457
+ Save a model and its configuration file to a directory so that it can be reloaded using the
458
+ [`~models.ModelMixin.from_pretrained`] class method.
459
+
460
+ Arguments:
461
+ save_directory (`str` or `os.PathLike`):
462
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
463
+ is_main_process (`bool`, *optional*, defaults to `True`):
464
+ Whether the process calling this is the main process or not. Useful during distributed training and you
465
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
466
+ process to avoid race conditions.
467
+ save_function (`Callable`):
468
+ The function to use to save the state dictionary. Useful during distributed training when you need to
469
+ replace `torch.save` with another method. Can be configured with the environment variable
470
+ `DIFFUSERS_SAVE_MODE`.
471
+ safe_serialization (`bool`, *optional*, defaults to `True`):
472
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
473
+ variant (`str`, *optional*):
474
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
475
+ push_to_hub (`bool`, *optional*, defaults to `False`):
476
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
477
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
478
+ namespace).
479
+ kwargs (`Dict[str, Any]`, *optional*):
480
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
481
+ """
482
+ if os.path.isfile(save_directory):
483
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
484
+ return
485
+
486
+ os.makedirs(save_directory, exist_ok=True)
487
+
488
+ if push_to_hub:
489
+ commit_message = kwargs.pop("commit_message", None)
490
+ private = kwargs.pop("private", False)
491
+ create_pr = kwargs.pop("create_pr", False)
492
+ token = kwargs.pop("token", None)
493
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
494
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
495
+
496
+ # Only save the model itself if we are using distributed training
497
+ model_to_save = self
498
+
499
+ # Attach architecture to the config
500
+ # Save the config
501
+ if is_main_process:
502
+ model_to_save.save_config(save_directory)
503
+
504
+ # Save the model
505
+ state_dict = model_to_save.state_dict()
506
+
507
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
508
+ weights_name = _add_variant(weights_name, variant)
509
+
510
+ # Save the model
511
+ if safe_serialization:
512
+ safetensors.torch.save_file(
513
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
514
+ )
515
+ else:
516
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
517
+
518
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
519
+
520
+ if push_to_hub:
521
+ self._upload_folder(
522
+ save_directory,
523
+ repo_id,
524
+ token=token,
525
+ commit_message=commit_message,
526
+ create_pr=create_pr,
527
+ )
528
+
529
+ @classmethod
530
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
531
+ r"""
532
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
533
+
534
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
535
+ train the model, set it back in training mode with `model.train()`.
536
+
537
+ Parameters:
538
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
539
+ Can be either:
540
+
541
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
542
+ the Hub.
543
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
544
+ with [`~ModelMixin.save_pretrained`].
545
+
546
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
547
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
548
+ is not used.
549
+ torch_dtype (`str` or `torch.dtype`, *optional*):
550
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
551
+ dtype is automatically derived from the model's weights.
552
+ force_download (`bool`, *optional*, defaults to `False`):
553
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
554
+ cached versions if they exist.
555
+ resume_download (`bool`, *optional*, defaults to `False`):
556
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
557
+ incompletely downloaded files are deleted.
558
+ proxies (`Dict[str, str]`, *optional*):
559
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
560
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
561
+ output_loading_info (`bool`, *optional*, defaults to `False`):
562
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
563
+ local_files_only(`bool`, *optional*, defaults to `False`):
564
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
565
+ won't be downloaded from the Hub.
566
+ use_auth_token (`str` or *bool*, *optional*):
567
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
568
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
569
+ revision (`str`, *optional*, defaults to `"main"`):
570
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
571
+ allowed by Git.
572
+ from_flax (`bool`, *optional*, defaults to `False`):
573
+ Load the model weights from a Flax checkpoint save file.
574
+ subfolder (`str`, *optional*, defaults to `""`):
575
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
576
+ mirror (`str`, *optional*):
577
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
578
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
579
+ information.
580
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
581
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
582
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
583
+ same device.
584
+
585
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
586
+ more information about each option see [designing a device
587
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
588
+ max_memory (`Dict`, *optional*):
589
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
590
+ each GPU and the available CPU RAM if unset.
591
+ offload_folder (`str` or `os.PathLike`, *optional*):
592
+ The path to offload weights if `device_map` contains the value `"disk"`.
593
+ offload_state_dict (`bool`, *optional*):
594
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
595
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
596
+ when there is some disk offload.
597
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
598
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
599
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
600
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
601
+ argument to `True` will raise an error.
602
+ variant (`str`, *optional*):
603
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
604
+ loading `from_flax`.
605
+ use_safetensors (`bool`, *optional*, defaults to `None`):
606
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
607
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
608
+ weights. If set to `False`, `safetensors` weights are not loaded.
609
+
610
+ <Tip>
611
+
612
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
613
+ `huggingface-cli login`. You can also activate the special
614
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
615
+ firewalled environment.
616
+
617
+ </Tip>
618
+
619
+ Example:
620
+
621
+ ```py
622
+ from diffusers import UNet2DConditionModel
623
+
624
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
625
+ ```
626
+
627
+ If you get the error message below, you need to finetune the weights for your downstream task:
628
+
629
+ ```bash
630
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
631
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
632
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
633
+ ```
634
+ """
635
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
636
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
637
+ force_download = kwargs.pop("force_download", False)
638
+ from_flax = kwargs.pop("from_flax", False)
639
+ resume_download = kwargs.pop("resume_download", False)
640
+ proxies = kwargs.pop("proxies", None)
641
+ output_loading_info = kwargs.pop("output_loading_info", False)
642
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
643
+ use_auth_token = kwargs.pop("use_auth_token", None)
644
+ revision = kwargs.pop("revision", None)
645
+ torch_dtype = kwargs.pop("torch_dtype", None)
646
+ subfolder = kwargs.pop("subfolder", None)
647
+ device_map = kwargs.pop("device_map", None)
648
+ max_memory = kwargs.pop("max_memory", None)
649
+ offload_folder = kwargs.pop("offload_folder", None)
650
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
651
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
652
+ variant = kwargs.pop("variant", None)
653
+ use_safetensors = kwargs.pop("use_safetensors", None)
654
+
655
+ allow_pickle = False
656
+ if use_safetensors is None:
657
+ use_safetensors = True
658
+ allow_pickle = True
659
+
660
+ if low_cpu_mem_usage and not is_accelerate_available():
661
+ low_cpu_mem_usage = False
662
+ logger.warning(
663
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
664
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
665
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
666
+ " install accelerate\n```\n."
667
+ )
668
+
669
+ if device_map is not None and not is_accelerate_available():
670
+ raise NotImplementedError(
671
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
672
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
673
+ )
674
+
675
+ # Check if we can handle device_map and dispatching the weights
676
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
677
+ raise NotImplementedError(
678
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
679
+ " `device_map=None`."
680
+ )
681
+
682
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
683
+ raise NotImplementedError(
684
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
685
+ " `low_cpu_mem_usage=False`."
686
+ )
687
+
688
+ if low_cpu_mem_usage is False and device_map is not None:
689
+ raise ValueError(
690
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
691
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
692
+ )
693
+
694
+ # Load config if we don't provide a configuration
695
+ config_path = pretrained_model_name_or_path
696
+
697
+ user_agent = {
698
+ "diffusers": __version__,
699
+ "file_type": "model",
700
+ "framework": "pytorch",
701
+ }
702
+
703
+ # load config
704
+ config, unused_kwargs, commit_hash = cls.load_config(
705
+ config_path,
706
+ cache_dir=cache_dir,
707
+ return_unused_kwargs=True,
708
+ return_commit_hash=True,
709
+ force_download=force_download,
710
+ resume_download=resume_download,
711
+ proxies=proxies,
712
+ local_files_only=local_files_only,
713
+ use_auth_token=use_auth_token,
714
+ revision=revision,
715
+ subfolder=subfolder,
716
+ device_map=device_map,
717
+ max_memory=max_memory,
718
+ offload_folder=offload_folder,
719
+ offload_state_dict=offload_state_dict,
720
+ user_agent=user_agent,
721
+ **kwargs,
722
+ )
723
+
724
+ # load model
725
+ model_file = None
726
+ if from_flax:
727
+ model_file = _get_model_file(
728
+ pretrained_model_name_or_path,
729
+ weights_name=FLAX_WEIGHTS_NAME,
730
+ cache_dir=cache_dir,
731
+ force_download=force_download,
732
+ resume_download=resume_download,
733
+ proxies=proxies,
734
+ local_files_only=local_files_only,
735
+ use_auth_token=use_auth_token,
736
+ revision=revision,
737
+ subfolder=subfolder,
738
+ user_agent=user_agent,
739
+ commit_hash=commit_hash,
740
+ )
741
+ model = cls.from_config(config, **unused_kwargs)
742
+
743
+ # Convert the weights
744
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
745
+
746
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
747
+ else:
748
+ if use_safetensors:
749
+ try:
750
+ model_file = _get_model_file(
751
+ pretrained_model_name_or_path,
752
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
753
+ cache_dir=cache_dir,
754
+ force_download=force_download,
755
+ resume_download=resume_download,
756
+ proxies=proxies,
757
+ local_files_only=local_files_only,
758
+ use_auth_token=use_auth_token,
759
+ revision=revision,
760
+ subfolder=subfolder,
761
+ user_agent=user_agent,
762
+ commit_hash=commit_hash,
763
+ )
764
+ except IOError as e:
765
+ if not allow_pickle:
766
+ raise e
767
+ pass
768
+ if model_file is None:
769
+ model_file = _get_model_file(
770
+ pretrained_model_name_or_path,
771
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
772
+ cache_dir=cache_dir,
773
+ force_download=force_download,
774
+ resume_download=resume_download,
775
+ proxies=proxies,
776
+ local_files_only=local_files_only,
777
+ use_auth_token=use_auth_token,
778
+ revision=revision,
779
+ subfolder=subfolder,
780
+ user_agent=user_agent,
781
+ commit_hash=commit_hash,
782
+ )
783
+
784
+ if low_cpu_mem_usage:
785
+ # Instantiate model with empty weights
786
+ with accelerate.init_empty_weights():
787
+ model = cls.from_config(config, **unused_kwargs)
788
+
789
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
790
+ if device_map is None:
791
+ param_device = "cpu"
792
+ state_dict = load_state_dict(model_file, variant=variant)
793
+ model._convert_deprecated_attention_blocks(state_dict)
794
+ # move the params from meta device to cpu
795
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
796
+ if len(missing_keys) > 0:
797
+ raise ValueError(
798
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
799
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
800
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
801
+ " those weights or else make sure your checkpoint file is correct."
802
+ )
803
+
804
+ unexpected_keys = load_model_dict_into_meta(
805
+ model,
806
+ state_dict,
807
+ device=param_device,
808
+ dtype=torch_dtype,
809
+ model_name_or_path=pretrained_model_name_or_path,
810
+ )
811
+
812
+ if cls._keys_to_ignore_on_load_unexpected is not None:
813
+ for pat in cls._keys_to_ignore_on_load_unexpected:
814
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
815
+
816
+ if len(unexpected_keys) > 0:
817
+ logger.warn(
818
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
819
+ )
820
+
821
+ else: # else let accelerate handle loading and dispatching.
822
+ # Load weights and dispatch according to the device_map
823
+ # by default the device_map is None and the weights are loaded on the CPU
824
+ try:
825
+ accelerate.load_checkpoint_and_dispatch(
826
+ model,
827
+ model_file,
828
+ device_map,
829
+ max_memory=max_memory,
830
+ offload_folder=offload_folder,
831
+ offload_state_dict=offload_state_dict,
832
+ dtype=torch_dtype,
833
+ )
834
+ except AttributeError as e:
835
+ # When using accelerate loading, we do not have the ability to load the state
836
+ # dict and rename the weight names manually. Additionally, accelerate skips
837
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
838
+ # (which look like they should be private variables?), so we can't use the standard hooks
839
+ # to rename parameters on load. We need to mimic the original weight names so the correct
840
+ # attributes are available. After we have loaded the weights, we convert the deprecated
841
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
842
+ # the weights so we don't have to do this again.
843
+
844
+ if "'Attention' object has no attribute" in str(e):
845
+ logger.warn(
846
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
847
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
848
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
849
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
850
+ " please also re-upload it or open a PR on the original repository."
851
+ )
852
+ model._temp_convert_self_to_deprecated_attention_blocks()
853
+ accelerate.load_checkpoint_and_dispatch(
854
+ model,
855
+ model_file,
856
+ device_map,
857
+ max_memory=max_memory,
858
+ offload_folder=offload_folder,
859
+ offload_state_dict=offload_state_dict,
860
+ dtype=torch_dtype,
861
+ )
862
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
863
+ else:
864
+ raise e
865
+
866
+ loading_info = {
867
+ "missing_keys": [],
868
+ "unexpected_keys": [],
869
+ "mismatched_keys": [],
870
+ "error_msgs": [],
871
+ }
872
+ else:
873
+ model = cls.from_config(config, **unused_kwargs)
874
+
875
+ state_dict = load_state_dict(model_file, variant=variant)
876
+ model._convert_deprecated_attention_blocks(state_dict)
877
+
878
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
879
+ model,
880
+ state_dict,
881
+ model_file,
882
+ pretrained_model_name_or_path,
883
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
884
+ )
885
+
886
+ loading_info = {
887
+ "missing_keys": missing_keys,
888
+ "unexpected_keys": unexpected_keys,
889
+ "mismatched_keys": mismatched_keys,
890
+ "error_msgs": error_msgs,
891
+ }
892
+
893
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
894
+ raise ValueError(
895
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
896
+ )
897
+ elif torch_dtype is not None:
898
+ model = model.to(torch_dtype)
899
+
900
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
901
+
902
+ # Set model in evaluation mode to deactivate DropOut modules by default
903
+ model.eval()
904
+ if output_loading_info:
905
+ return model, loading_info
906
+
907
+ return model
908
+
909
+ @classmethod
910
+ def _load_pretrained_model(
911
+ cls,
912
+ model,
913
+ state_dict,
914
+ resolved_archive_file,
915
+ pretrained_model_name_or_path,
916
+ ignore_mismatched_sizes=False,
917
+ ):
918
+ # Retrieve missing & unexpected_keys
919
+ model_state_dict = model.state_dict()
920
+ loaded_keys = list(state_dict.keys())
921
+
922
+ expected_keys = list(model_state_dict.keys())
923
+
924
+ original_loaded_keys = loaded_keys
925
+
926
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
927
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
928
+
929
+ # Make sure we are able to load base models as well as derived models (with heads)
930
+ model_to_load = model
931
+
932
+ def _find_mismatched_keys(
933
+ state_dict,
934
+ model_state_dict,
935
+ loaded_keys,
936
+ ignore_mismatched_sizes,
937
+ ):
938
+ mismatched_keys = []
939
+ if ignore_mismatched_sizes:
940
+ for checkpoint_key in loaded_keys:
941
+ model_key = checkpoint_key
942
+
943
+ if (
944
+ model_key in model_state_dict
945
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
946
+ ):
947
+ mismatched_keys.append(
948
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
949
+ )
950
+ del state_dict[checkpoint_key]
951
+ return mismatched_keys
952
+
953
+ if state_dict is not None:
954
+ # Whole checkpoint
955
+ mismatched_keys = _find_mismatched_keys(
956
+ state_dict,
957
+ model_state_dict,
958
+ original_loaded_keys,
959
+ ignore_mismatched_sizes,
960
+ )
961
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
962
+
963
+ if len(error_msgs) > 0:
964
+ error_msg = "\n\t".join(error_msgs)
965
+ if "size mismatch" in error_msg:
966
+ error_msg += (
967
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
968
+ )
969
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
970
+
971
+ if len(unexpected_keys) > 0:
972
+ logger.warning(
973
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
974
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
975
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
976
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
977
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
978
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
979
+ " identical (initializing a BertForSequenceClassification model from a"
980
+ " BertForSequenceClassification model)."
981
+ )
982
+ else:
983
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
984
+ if len(missing_keys) > 0:
985
+ logger.warning(
986
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
987
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
988
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
989
+ )
990
+ elif len(mismatched_keys) == 0:
991
+ logger.info(
992
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
993
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
994
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
995
+ " without further training."
996
+ )
997
+ if len(mismatched_keys) > 0:
998
+ mismatched_warning = "\n".join(
999
+ [
1000
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1001
+ for key, shape1, shape2 in mismatched_keys
1002
+ ]
1003
+ )
1004
+ logger.warning(
1005
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1006
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1007
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1008
+ " able to use it for predictions and inference."
1009
+ )
1010
+
1011
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1012
+
1013
+ @property
1014
+ def device(self) -> device:
1015
+ """
1016
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1017
+ device).
1018
+ """
1019
+ return get_parameter_device(self)
1020
+
1021
+ @property
1022
+ def dtype(self) -> torch.dtype:
1023
+ """
1024
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1025
+ """
1026
+ return get_parameter_dtype(self)
1027
+
1028
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
1029
+ """
1030
+ Get number of (trainable or non-embedding) parameters in the module.
1031
+
1032
+ Args:
1033
+ only_trainable (`bool`, *optional*, defaults to `False`):
1034
+ Whether or not to return only the number of trainable parameters.
1035
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
1036
+ Whether or not to return only the number of non-embedding parameters.
1037
+
1038
+ Returns:
1039
+ `int`: The number of parameters.
1040
+
1041
+ Example:
1042
+
1043
+ ```py
1044
+ from diffusers import UNet2DConditionModel
1045
+
1046
+ model_id = "runwayml/stable-diffusion-v1-5"
1047
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
1048
+ unet.num_parameters(only_trainable=True)
1049
+ 859520964
1050
+ ```
1051
+ """
1052
+
1053
+ if exclude_embeddings:
1054
+ embedding_param_names = [
1055
+ f"{name}.weight"
1056
+ for name, module_type in self.named_modules()
1057
+ if isinstance(module_type, torch.nn.Embedding)
1058
+ ]
1059
+ non_embedding_parameters = [
1060
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
1061
+ ]
1062
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
1063
+ else:
1064
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
1065
+
1066
+ def _convert_deprecated_attention_blocks(self, state_dict):
1067
+ deprecated_attention_block_paths = []
1068
+
1069
+ def recursive_find_attn_block(name, module):
1070
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1071
+ deprecated_attention_block_paths.append(name)
1072
+
1073
+ for sub_name, sub_module in module.named_children():
1074
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
1075
+ recursive_find_attn_block(sub_name, sub_module)
1076
+
1077
+ recursive_find_attn_block("", self)
1078
+
1079
+ # NOTE: we have to check if the deprecated parameters are in the state dict
1080
+ # because it is possible we are loading from a state dict that was already
1081
+ # converted
1082
+
1083
+ for path in deprecated_attention_block_paths:
1084
+ # group_norm path stays the same
1085
+
1086
+ # query -> to_q
1087
+ if f"{path}.query.weight" in state_dict:
1088
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
1089
+ if f"{path}.query.bias" in state_dict:
1090
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
1091
+
1092
+ # key -> to_k
1093
+ if f"{path}.key.weight" in state_dict:
1094
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
1095
+ if f"{path}.key.bias" in state_dict:
1096
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
1097
+
1098
+ # value -> to_v
1099
+ if f"{path}.value.weight" in state_dict:
1100
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
1101
+ if f"{path}.value.bias" in state_dict:
1102
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
1103
+
1104
+ # proj_attn -> to_out.0
1105
+ if f"{path}.proj_attn.weight" in state_dict:
1106
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
1107
+ if f"{path}.proj_attn.bias" in state_dict:
1108
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
1109
+
1110
+ def _temp_convert_self_to_deprecated_attention_blocks(self):
1111
+ deprecated_attention_block_modules = []
1112
+
1113
+ def recursive_find_attn_block(module):
1114
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1115
+ deprecated_attention_block_modules.append(module)
1116
+
1117
+ for sub_module in module.children():
1118
+ recursive_find_attn_block(sub_module)
1119
+
1120
+ recursive_find_attn_block(self)
1121
+
1122
+ for module in deprecated_attention_block_modules:
1123
+ module.query = module.to_q
1124
+ module.key = module.to_k
1125
+ module.value = module.to_v
1126
+ module.proj_attn = module.to_out[0]
1127
+
1128
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
1129
+ # that _all_ the weights are loaded into the new attributes and we're not
1130
+ # making an incorrect assumption that this model should be converted when
1131
+ # it really shouldn't be.
1132
+ del module.to_q
1133
+ del module.to_k
1134
+ del module.to_v
1135
+ del module.to_out
1136
+
1137
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
1138
+ deprecated_attention_block_modules = []
1139
+
1140
+ def recursive_find_attn_block(module):
1141
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1142
+ deprecated_attention_block_modules.append(module)
1143
+
1144
+ for sub_module in module.children():
1145
+ recursive_find_attn_block(sub_module)
1146
+
1147
+ recursive_find_attn_block(self)
1148
+
1149
+ for module in deprecated_attention_block_modules:
1150
+ module.to_q = module.query
1151
+ module.to_k = module.key
1152
+ module.to_v = module.value
1153
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
1154
+
1155
+ del module.query
1156
+ del module.key
1157
+ del module.value
1158
+ del module.proj_attn
diffusers/models/normalization.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from .activations import get_activation
23
+ from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
24
+
25
+
26
+ class AdaLayerNorm(nn.Module):
27
+ r"""
28
+ Norm layer modified to incorporate timestep embeddings.
29
+
30
+ Parameters:
31
+ embedding_dim (`int`): The size of each embedding vector.
32
+ num_embeddings (`int`): The size of the embeddings dictionary.
33
+ """
34
+
35
+ def __init__(self, embedding_dim: int, num_embeddings: int):
36
+ super().__init__()
37
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
38
+ self.silu = nn.SiLU()
39
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
40
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
41
+
42
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
43
+ emb = self.linear(self.silu(self.emb(timestep)))
44
+ scale, shift = torch.chunk(emb, 2)
45
+ x = self.norm(x) * (1 + scale) + shift
46
+ return x
47
+
48
+
49
+ class AdaLayerNormZero(nn.Module):
50
+ r"""
51
+ Norm layer adaptive layer norm zero (adaLN-Zero).
52
+
53
+ Parameters:
54
+ embedding_dim (`int`): The size of each embedding vector.
55
+ num_embeddings (`int`): The size of the embeddings dictionary.
56
+ """
57
+
58
+ def __init__(self, embedding_dim: int, num_embeddings: int):
59
+ super().__init__()
60
+
61
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
62
+
63
+ self.silu = nn.SiLU()
64
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
65
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
66
+
67
+ def forward(
68
+ self,
69
+ x: torch.Tensor,
70
+ timestep: torch.Tensor,
71
+ class_labels: torch.LongTensor,
72
+ hidden_dtype: Optional[torch.dtype] = None,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
75
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
76
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
77
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
78
+
79
+
80
+ class AdaLayerNormSingle(nn.Module):
81
+ r"""
82
+ Norm layer adaptive layer norm single (adaLN-single).
83
+
84
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
85
+
86
+ Parameters:
87
+ embedding_dim (`int`): The size of each embedding vector.
88
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
89
+ """
90
+
91
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
92
+ super().__init__()
93
+
94
+ self.emb = CombinedTimestepSizeEmbeddings(
95
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
96
+ )
97
+
98
+ self.silu = nn.SiLU()
99
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
100
+
101
+ def forward(
102
+ self,
103
+ timestep: torch.Tensor,
104
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
105
+ batch_size: int = None,
106
+ hidden_dtype: Optional[torch.dtype] = None,
107
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
108
+ # No modulation happening here.
109
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
110
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
111
+
112
+
113
+ class AdaGroupNorm(nn.Module):
114
+ r"""
115
+ GroupNorm layer modified to incorporate timestep embeddings.
116
+
117
+ Parameters:
118
+ embedding_dim (`int`): The size of each embedding vector.
119
+ num_embeddings (`int`): The size of the embeddings dictionary.
120
+ num_groups (`int`): The number of groups to separate the channels into.
121
+ act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
122
+ eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
123
+ """
124
+
125
+ def __init__(
126
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
127
+ ):
128
+ super().__init__()
129
+ self.num_groups = num_groups
130
+ self.eps = eps
131
+
132
+ if act_fn is None:
133
+ self.act = None
134
+ else:
135
+ self.act = get_activation(act_fn)
136
+
137
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
138
+
139
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
140
+ if self.act:
141
+ emb = self.act(emb)
142
+ emb = self.linear(emb)
143
+ emb = emb[:, :, None, None]
144
+ scale, shift = emb.chunk(2, dim=1)
145
+
146
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
147
+ x = x * (1 + scale) + shift
148
+ return x
diffusers/models/prior_transformer.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from ..configuration_utils import ConfigMixin, register_to_config
9
+ from ..loaders import UNet2DConditionLoadersMixin
10
+ from ..utils import BaseOutput
11
+ from .attention import BasicTransformerBlock
12
+ from .attention_processor import (
13
+ ADDED_KV_ATTENTION_PROCESSORS,
14
+ CROSS_ATTENTION_PROCESSORS,
15
+ AttentionProcessor,
16
+ AttnAddedKVProcessor,
17
+ AttnProcessor,
18
+ )
19
+ from .embeddings import TimestepEmbedding, Timesteps
20
+ from .modeling_utils import ModelMixin
21
+
22
+
23
+ @dataclass
24
+ class PriorTransformerOutput(BaseOutput):
25
+ """
26
+ The output of [`PriorTransformer`].
27
+
28
+ Args:
29
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
30
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
31
+ """
32
+
33
+ predicted_image_embedding: torch.FloatTensor
34
+
35
+
36
+ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
37
+ """
38
+ A Prior Transformer model.
39
+
40
+ Parameters:
41
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
42
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
43
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
44
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
45
+ num_embeddings (`int`, *optional*, defaults to 77):
46
+ The number of embeddings of the model input `hidden_states`
47
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
48
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
49
+ additional_embeddings`.
50
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
51
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
52
+ The activation function to use to create timestep embeddings.
53
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
54
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
55
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
56
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
57
+ needed.
58
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
59
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
60
+ `encoder_hidden_states` is `None`.
61
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
62
+ Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
63
+ product between the text embedding and image embedding as proposed in the unclip paper
64
+ https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
65
+ time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
66
+ If None, will be set to `num_attention_heads * attention_head_dim`
67
+ embedding_proj_dim (`int`, *optional*, default to None):
68
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
69
+ clip_embed_dim (`int`, *optional*, default to None):
70
+ The dimension of the output. If None, will be set to `embedding_dim`.
71
+ """
72
+
73
+ @register_to_config
74
+ def __init__(
75
+ self,
76
+ num_attention_heads: int = 32,
77
+ attention_head_dim: int = 64,
78
+ num_layers: int = 20,
79
+ embedding_dim: int = 768,
80
+ num_embeddings=77,
81
+ additional_embeddings=4,
82
+ dropout: float = 0.0,
83
+ time_embed_act_fn: str = "silu",
84
+ norm_in_type: Optional[str] = None, # layer
85
+ embedding_proj_norm_type: Optional[str] = None, # layer
86
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
87
+ added_emb_type: Optional[str] = "prd", # prd
88
+ time_embed_dim: Optional[int] = None,
89
+ embedding_proj_dim: Optional[int] = None,
90
+ clip_embed_dim: Optional[int] = None,
91
+ ):
92
+ super().__init__()
93
+ self.num_attention_heads = num_attention_heads
94
+ self.attention_head_dim = attention_head_dim
95
+ inner_dim = num_attention_heads * attention_head_dim
96
+ self.additional_embeddings = additional_embeddings
97
+
98
+ time_embed_dim = time_embed_dim or inner_dim
99
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
100
+ clip_embed_dim = clip_embed_dim or embedding_dim
101
+
102
+ self.time_proj = Timesteps(inner_dim, True, 0)
103
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
104
+
105
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
106
+
107
+ if embedding_proj_norm_type is None:
108
+ self.embedding_proj_norm = None
109
+ elif embedding_proj_norm_type == "layer":
110
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
111
+ else:
112
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
113
+
114
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
115
+
116
+ if encoder_hid_proj_type is None:
117
+ self.encoder_hidden_states_proj = None
118
+ elif encoder_hid_proj_type == "linear":
119
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
120
+ else:
121
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
122
+
123
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
124
+
125
+ if added_emb_type == "prd":
126
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
127
+ elif added_emb_type is None:
128
+ self.prd_embedding = None
129
+ else:
130
+ raise ValueError(
131
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
132
+ )
133
+
134
+ self.transformer_blocks = nn.ModuleList(
135
+ [
136
+ BasicTransformerBlock(
137
+ inner_dim,
138
+ num_attention_heads,
139
+ attention_head_dim,
140
+ dropout=dropout,
141
+ activation_fn="gelu",
142
+ attention_bias=True,
143
+ )
144
+ for d in range(num_layers)
145
+ ]
146
+ )
147
+
148
+ if norm_in_type == "layer":
149
+ self.norm_in = nn.LayerNorm(inner_dim)
150
+ elif norm_in_type is None:
151
+ self.norm_in = None
152
+ else:
153
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
154
+
155
+ self.norm_out = nn.LayerNorm(inner_dim)
156
+
157
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
158
+
159
+ causal_attention_mask = torch.full(
160
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
161
+ )
162
+ causal_attention_mask.triu_(1)
163
+ causal_attention_mask = causal_attention_mask[None, ...]
164
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
165
+
166
+ self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
167
+ self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
168
+
169
+ @property
170
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
171
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
172
+ r"""
173
+ Returns:
174
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
175
+ indexed by its weight name.
176
+ """
177
+ # set recursively
178
+ processors = {}
179
+
180
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
181
+ if hasattr(module, "get_processor"):
182
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
183
+
184
+ for sub_name, child in module.named_children():
185
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
186
+
187
+ return processors
188
+
189
+ for name, module in self.named_children():
190
+ fn_recursive_add_processors(name, module, processors)
191
+
192
+ return processors
193
+
194
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
195
+ def set_attn_processor(
196
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
197
+ ):
198
+ r"""
199
+ Sets the attention processor to use to compute attention.
200
+
201
+ Parameters:
202
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
203
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
204
+ for **all** `Attention` layers.
205
+
206
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
207
+ processor. This is strongly recommended when setting trainable attention processors.
208
+
209
+ """
210
+ count = len(self.attn_processors.keys())
211
+
212
+ if isinstance(processor, dict) and len(processor) != count:
213
+ raise ValueError(
214
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
215
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
216
+ )
217
+
218
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
219
+ if hasattr(module, "set_processor"):
220
+ if not isinstance(processor, dict):
221
+ module.set_processor(processor, _remove_lora=_remove_lora)
222
+ else:
223
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
224
+
225
+ for sub_name, child in module.named_children():
226
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
227
+
228
+ for name, module in self.named_children():
229
+ fn_recursive_attn_processor(name, module, processor)
230
+
231
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
232
+ def set_default_attn_processor(self):
233
+ """
234
+ Disables custom attention processors and sets the default attention implementation.
235
+ """
236
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
237
+ processor = AttnAddedKVProcessor()
238
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
239
+ processor = AttnProcessor()
240
+ else:
241
+ raise ValueError(
242
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
243
+ )
244
+
245
+ self.set_attn_processor(processor, _remove_lora=True)
246
+
247
+ def forward(
248
+ self,
249
+ hidden_states,
250
+ timestep: Union[torch.Tensor, float, int],
251
+ proj_embedding: torch.FloatTensor,
252
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
253
+ attention_mask: Optional[torch.BoolTensor] = None,
254
+ return_dict: bool = True,
255
+ ):
256
+ """
257
+ The [`PriorTransformer`] forward method.
258
+
259
+ Args:
260
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
261
+ The currently predicted image embeddings.
262
+ timestep (`torch.LongTensor`):
263
+ Current denoising step.
264
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
265
+ Projected embedding vector the denoising process is conditioned on.
266
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
267
+ Hidden states of the text embeddings the denoising process is conditioned on.
268
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
269
+ Text mask for the text embeddings.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
272
+ tuple.
273
+
274
+ Returns:
275
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
276
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
277
+ tuple is returned where the first element is the sample tensor.
278
+ """
279
+ batch_size = hidden_states.shape[0]
280
+
281
+ timesteps = timestep
282
+ if not torch.is_tensor(timesteps):
283
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
284
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
285
+ timesteps = timesteps[None].to(hidden_states.device)
286
+
287
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
288
+ timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
289
+
290
+ timesteps_projected = self.time_proj(timesteps)
291
+
292
+ # timesteps does not contain any weights and will always return f32 tensors
293
+ # but time_embedding might be fp16, so we need to cast here.
294
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype)
295
+ time_embeddings = self.time_embedding(timesteps_projected)
296
+
297
+ if self.embedding_proj_norm is not None:
298
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
299
+
300
+ proj_embeddings = self.embedding_proj(proj_embedding)
301
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
302
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
303
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
304
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
305
+
306
+ hidden_states = self.proj_in(hidden_states)
307
+
308
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
309
+
310
+ additional_embeds = []
311
+ additional_embeddings_len = 0
312
+
313
+ if encoder_hidden_states is not None:
314
+ additional_embeds.append(encoder_hidden_states)
315
+ additional_embeddings_len += encoder_hidden_states.shape[1]
316
+
317
+ if len(proj_embeddings.shape) == 2:
318
+ proj_embeddings = proj_embeddings[:, None, :]
319
+
320
+ if len(hidden_states.shape) == 2:
321
+ hidden_states = hidden_states[:, None, :]
322
+
323
+ additional_embeds = additional_embeds + [
324
+ proj_embeddings,
325
+ time_embeddings[:, None, :],
326
+ hidden_states,
327
+ ]
328
+
329
+ if self.prd_embedding is not None:
330
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
331
+ additional_embeds.append(prd_embedding)
332
+
333
+ hidden_states = torch.cat(
334
+ additional_embeds,
335
+ dim=1,
336
+ )
337
+
338
+ # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
339
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
340
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
341
+ positional_embeddings = F.pad(
342
+ positional_embeddings,
343
+ (
344
+ 0,
345
+ 0,
346
+ additional_embeddings_len,
347
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
348
+ ),
349
+ value=0.0,
350
+ )
351
+
352
+ hidden_states = hidden_states + positional_embeddings
353
+
354
+ if attention_mask is not None:
355
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
356
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
357
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
358
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
359
+
360
+ if self.norm_in is not None:
361
+ hidden_states = self.norm_in(hidden_states)
362
+
363
+ for block in self.transformer_blocks:
364
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
365
+
366
+ hidden_states = self.norm_out(hidden_states)
367
+
368
+ if self.prd_embedding is not None:
369
+ hidden_states = hidden_states[:, -1]
370
+ else:
371
+ hidden_states = hidden_states[:, additional_embeddings_len:]
372
+
373
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
374
+
375
+ if not return_dict:
376
+ return (predicted_image_embedding,)
377
+
378
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
379
+
380
+ def post_process_latents(self, prior_latents):
381
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
382
+ return prior_latents