Spaces:
Paused
Upload 14 files
Browse files
- .gitattributes +2 -0
- LICENSE +201 -0
- README.md +132 -13
- app.py +156 -0
- app_utils.py +99 -0
- assets/mountain.mp4 +3 -0
- assets/river.mp4 +3 -0
- assets/woman.mp4 +0 -0
- conversion_utils.py +103 -0
- inference.py +139 -0
- pipeline_ltx.py +811 -0
- prompt_embeds.pt +3 -0
- q8_attention_processors.py +52 -0
- q8_ltx.py +141 -0
- requirements.txt +7 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/mountain.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/river.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   [Sections 1-9 of the standard, unmodified Apache License 2.0 text, followed by the
   appendix on how to apply the license and the "Copyright [yyyy] [name of copyright owner]"
   boilerplate notice.]
README.md
CHANGED
@@ -1,13 +1,132 @@

# Q8 LTX-Video optimized for Ada

This repository shows how to use the Q8 kernels from [`KONAKONA666/q8_kernels`](https://github.com/KONAKONA666/q8_kernels) with `diffusers` to optimize inference of [LTX-Video](https://huggingface.co/Lightricks/LTX-Video) on Ada GPUs. Go from 16.192 secs to 9.572 secs while reducing memory from 7 GB to 5 GB without quality loss 🤪 With `torch.compile()`, the time reduces further to 6.747 secs 🔥

The Q8 transformer checkpoint is available here: [`sayakpaul/q8-ltx-video`](https://hf.co/sayakpaul/q8-ltx-video).

## Getting started

Install the dependencies:

```bash
pip install -U transformers accelerate
git clone https://github.com/huggingface/diffusers && cd diffusers && pip install -e . && cd ..
```

Then install `q8_kernels`, following the instructions from [here](https://github.com/KONAKONA666/q8_kernels/?tab=readme-ov-file#installation).

To run inference with the Q8 kernels, we need some minor changes in `diffusers`. Apply [this patch](https://github.com/sayakpaul/q8-ltx-video/blob/368f549ca5136daf89049c9efe32748e73aca317/updates.patch) to take those into account:

```bash
git apply updates.patch
```

Now we can run inference:

```bash
python inference.py \
  --prompt="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" \
  --negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted" \
  --q8_transformer_path="sayakpaul/q8-ltx-video"
```

## Why does this repo exist?

There already exists [`KONAKONA666/LTX-Video`](https://github.com/KONAKONA666/LTX-Video). So why this repo?

That repo uses custom implementations of the LTX-Video pipeline components, which makes it hard to use directly with `diffusers`. This repo applies the kernels from `q8_kernels` to the components that ship with `diffusers`.

<details>
<summary>More details</summary>

We do this by first converting the state dict of the original [LTX-Video transformer](https://huggingface.co/Lightricks/LTX-Video/tree/main/transformer). This includes 8-bit (int8) quantization of the weights. The process also requires replacing:

* the linear layers of the model
* the RMSNorms of the model
* the GELUs of the model

before the converted state dict is loaded into the model. Some layer params are kept in FP32, and some layers are not quantized at all. The replacement utilities are in [`q8_ltx.py`](./q8_ltx.py).

The model can then be serialized. The conversion and serialization are implemented in [`conversion_utils.py`](./conversion_utils.py), as sketched below.
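
For reference, a minimal sketch of that conversion flow, mirroring `main()` in [`conversion_utils.py`](./conversion_utils.py) from this repo (the output path is illustrative):

```python
# Sketch only: quantize the original transformer and save it in the Q8 layout,
# following conversion_utils.py from this repo.
import torch
from diffusers import LTXVideoTransformer3DModel
from conversion_utils import convert_state_dict
from q8_ltx import replace_gelu, replace_linear, replace_rms_norm

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer"
).to("cuda")

# Quantize the weights first, then swap in the Q8 modules and load the
# converted state dict into them.
new_state_dict = convert_state_dict(transformer.state_dict())
transformer = replace_gelu(transformer)[0]
transformer = replace_linear(transformer)[0]
transformer = replace_rms_norm(transformer)[0]
transformer.load_state_dict(new_state_dict, strict=True)

transformer.save_pretrained("q8-ltx-video")  # illustrative output path
```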

When loading the model for inference, we:

* initialize the transformer model on the "meta" device
* follow the same layer replacement scheme as detailed above
* populate the converted state dict
* replace the attention processors with ones that use [the flash attention implementation](https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/functional/flash_attention.py) from `q8_kernels`

Refer [here](https://github.com/sayakpaul/q8-ltx-video/blob/368f549ca5136daf89049c9efe32748e73aca317/inference.py#L48) for more details; a minimal sketch of this flow is shown below. Additionally, we leverage the [flash-attention-based attention processors](https://github.com/sayakpaul/q8-ltx-video/blob/368f549ca5136daf89049c9efe32748e73aca317/q8_attention_processors.py#L44) built on `q8_kernels`, which provide a further speedup.

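A minimal sketch of that loading flow, mirroring `load_q8_transformer()` in [`inference.py`](./inference.py) from this repo:

```python
# Sketch only: build the Q8 transformer without materializing the original weights,
# following load_q8_transformer() in inference.py from this repo.
import safetensors.torch
import torch
from diffusers import LTXVideoTransformer3DModel
from huggingface_hub import hf_hub_download
from q8_ltx import replace_gelu, replace_linear, replace_rms_norm

# 1. Initialize the transformer on the "meta" device (no real weight allocation).
with torch.device("meta"):
    config = LTXVideoTransformer3DModel.load_config("Lightricks/LTX-Video", subfolder="transformer")
    transformer = LTXVideoTransformer3DModel.from_config(config)

# 2. Apply the same layer replacement scheme used during conversion.
transformer = replace_gelu(transformer)[0]
transformer = replace_linear(transformer)[0]
transformer = replace_rms_norm(transformer)[0]

# 3. Populate the converted (quantized) state dict.
state_dict = safetensors.torch.load_file(
    hf_hub_download("sayakpaul/q8-ltx-video", "diffusion_pytorch_model.safetensors")
)
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
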
</details>

## Performance

The numbers below were obtained with `max_sequence_length=512`, `num_inference_steps=50`, `num_frames=81`, and `resolution=480x704`. The rest of the arguments were kept at their default values, as per the [pipeline call signature of LTX-Video](https://github.com/huggingface/diffusers/blob/4b9f1c7d8c2e476eed38af3144b79105a5efcd93/src/diffusers/pipelines/ltx/pipeline_ltx.py#L496). The numbers don't include VAE decoding time, so they focus solely on the transformer.

| | **Time (Secs)** | **Memory (MB)** |
|:-----------:|:-----------:|:-----------:|
| Non Q8 | 16.192 | 7172.86 |
| Non Q8 (+ compile) | 16.205 | - |
| Q8 | 9.572 | 5413.51 |
| Q8 (+ compile) | 6.747 | - |

The benchmarking script is available in [`benchmark.py`](./benchmark.py). You will need to download the precomputed prompt embeddings from [here](https://huggingface.co/sayakpaul/q8-ltx-video/blob/main/prompt_embeds.pt) before running the benchmark.
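
`prompt_embeds.pt` holds a dict of precomputed prompt-embedding kwargs for the pipeline (see `warmup_transformer()` in [`app_utils.py`](./app_utils.py)). A minimal sketch of how a benchmark run can consume it, assuming `pipe` is an already-prepared LTX-Video pipeline with the Q8 transformer:

```python
# Sketch only: time the transformer-only path with precomputed prompt embeddings.
import torch

prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
latents = pipe(
    **prompt_embeds,
    output_type="latent",  # skip VAE decoding, matching how the timings above were measured
    width=704,
    height=480,
    num_frames=81,
    num_inference_steps=50,
    generator=torch.manual_seed(2025),
)
```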

<details>
<summary>Env</summary>

```bash
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
|  0%   46C    P8             18W /  450W |       2MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
```

`diffusers-cli env`:

```bash
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-6.8.0-49-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.27.0
- Transformers version: 4.47.1
- Accelerate version: 1.2.1
- PEFT version: 0.13.2
- Bitsandbytes version: 0.44.1
- Safetensors version: 0.4.4
- xFormers version: 0.0.29.post1
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
```

</details>

> [!NOTE]
> The RoPE implementation from `q8_kernels` [isn't usable as of 1st Jan 2025](https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/functional/rope.py#L26). So, we resort to using [the one](https://github.com/huggingface/diffusers/blob/91008aabc4b8dbd96a356ab6f457f3bd84b10e8b/src/diffusers/models/transformers/transformer_ltx.py#L464) from `diffusers`.

## Comparison

Check out [this page](https://wandb.ai/sayakpaul/q8-ltx-video/runs/89h6ac5) on Weights & Biases for some comparative results. Generated videos are also available [here](./videos/).

## Acknowledgement

KONAKONA666's work on [`KONAKONA666/q8_kernels`](https://github.com/KONAKONA666/q8_kernels) and [`KONAKONA666/LTX-Video`](https://github.com/KONAKONA666/LTX-Video).
app.py
ADDED
@@ -0,0 +1,156 @@
import gradio as gr
from app_utils import prepare_pipeline, compute_hash
import os
from diffusers.utils import export_to_video
import tempfile
import torch
from inference import load_text_encoding_pipeline


text_encoding_pipeline = load_text_encoding_pipeline()
inference_pipeline = prepare_pipeline()


def create_advanced_options():
    with gr.Accordion("Advanced Options (Optional)", open=False):
        seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=646373)
        inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=30)
        guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=3.0)
        max_sequence_length = gr.Slider(label="Maximum sequence length", minimum=128, maximum=512, step=1, value=128)
        fps = gr.Slider(label="FPS", minimum=21, maximum=30, step=1, value=24)
        return [
            seed,
            inference_steps,
            guidance_scale,
            max_sequence_length,
            fps,
        ]


@torch.no_grad()
def generate_video_from_text(prompt, negative_prompt, seed, steps, guidance_scale, max_sequence_length, fps):
    global text_encoding_pipeline
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = text_encoding_pipeline.encode_prompt(
        prompt=prompt, negative_prompt=negative_prompt, max_sequence_length=max_sequence_length
    )
    global inference_pipeline
    video = inference_pipeline(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        guidance_scale=guidance_scale,
        width=768,
        height=512,
        num_frames=121,
        num_inference_steps=steps,
        max_sequence_length=max_sequence_length,
        generator=torch.manual_seed(seed),
    ).frames[0]
    out_path = tempfile.mkstemp(suffix=".mp4")[1]  # mkstemp() returns (fd, path); keep the path
    export_to_video(video, out_path, fps=fps)
    return out_path


with gr.Blocks(theme=gr.themes.Soft()) as iface:
    with gr.Row(elem_id="title-row"):
        gr.Markdown(
            """
            <div style="text-align: center; margin-bottom: 1em">
            <h1 style="font-size: 2.5em; font-weight: 600; margin: 0.5em 0;">Fast Video Generation with <a href="https://github.com/sayakpaul/q8-ltx-video">Q8 LTX Video</a></h1>
            </div>
            """
        )
    with gr.Row(elem_id="title-row"):
        gr.HTML(  # add technical report link
            """
            <div style="display:flex;column-gap:4px;">
            <span>This space is modified from the original <a href="https://huggingface.co/spaces/Lightricks/LTX-Video-Playground">LTX-Video playground</a>. It uses optimized Q8 kernels along with torch.compile to allow for ultra-fast video
            generation. As a result, it restricts generations to 121 frames with 512x768 resolution. For more details, refer to <a href="https://github.com/sayakpaul/q8-ltx-video">this link</a>.</span>
            </div>
            """
        )
    with gr.Accordion(" 📖 Tips for Best Results", open=False, elem_id="instructions-accordion"):
        gr.Markdown(
            """
            📝 Prompt Engineering

            When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words.
            For best results, build your prompts using this structure:
            - Start with main action in a single sentence
            - Add specific details about movements and gestures
            - Describe character/object appearances precisely
            - Include background and environment details
            - Specify camera angles and movements
            - Describe lighting and colors
            - Note any changes or sudden events
            See examples for more inspiration.

            🎮 Parameter Guide
            - Resolution Preset: Higher resolutions for detailed scenes, lower for faster generation and simpler scenes
            - Seed: Save seed values to recreate specific styles or compositions you like
            - Guidance Scale: 3-3.5 are the recommended values
            - Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
            - When using detailed prompts, use a higher `max_sequence_length` value.
            """
        )

    with gr.Tabs():
        # Text to Video Tab
        with gr.TabItem("Text to Video"):
            with gr.Row():
                with gr.Column():
                    txt2vid_prompt = gr.Textbox(
                        label="Enter Your Prompt",
                        placeholder="Describe the video you want to generate (minimum 50 characters)...",
                        value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.",
                        lines=5,
                    )

                    txt2vid_negative_prompt = gr.Textbox(
                        label="Enter Negative Prompt",
                        placeholder="Describe what you don't want in the video...",
                        value="low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
                        lines=2,
                    )
                    txt2vid_advanced = create_advanced_options()

                    txt2vid_generate = gr.Button(
                        "Generate Video",
                        variant="primary",
                        size="lg",
                    )

                with gr.Column():
                    txt2vid_output = gr.Video(label="Generated Output")

            with gr.Row():
                gr.Examples(
                    examples=[
                        [
                            "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom.The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility.",
                            "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
                            "assets/river.mp4",
                        ],
                        [
                            "The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys.The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature.",
                            "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
                            "assets/mountain.mp4",
                        ],
                    ],
                    inputs=[txt2vid_prompt, txt2vid_negative_prompt, txt2vid_output],
                    label="Example Text-to-Video Generations",
                )

    txt2vid_generate.click(
        fn=generate_video_from_text,
        inputs=[
            txt2vid_prompt,
            txt2vid_negative_prompt,
            *txt2vid_advanced,
        ],
        outputs=txt2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video_from_text",
    )

if __name__ == "__main__":
    iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False)
app_utils.py
ADDED
@@ -0,0 +1,99 @@
from inference import load_q8_transformer
import hashlib
from q8_kernels.graph.graph import make_dynamic_graphed_callable
from argparse import Namespace
from diffusers import LTXPipeline
import types
import torch


# To account for the type-casting in `ff_output` of `LTXVideoTransformerBlock`
def patched_ltx_transformer_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb=None,
    encoder_attention_mask=None,
) -> torch.Tensor:
    batch_size = hidden_states.size(0)
    norm_hidden_states = self.norm1(hidden_states)

    num_ada_params = self.scale_shift_table.shape[0]
    ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

    attn_hidden_states = self.attn1(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=None,
        image_rotary_emb=image_rotary_emb,
    )
    hidden_states = hidden_states + attn_hidden_states * gate_msa

    attn_hidden_states = self.attn2(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        image_rotary_emb=None,
        attention_mask=encoder_attention_mask,
    )
    hidden_states = hidden_states + attn_hidden_states
    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

    ff_output = self.ff(norm_hidden_states).to(norm_hidden_states.dtype)
    hidden_states = hidden_states + ff_output * gate_mlp

    return hidden_states


def load_transformer():
    args = Namespace()
    args.q8_transformer_path = "sayakpaul/q8-ltx-video"
    transformer = load_q8_transformer(args)

    transformer.to(torch.bfloat16)
    for b in transformer.transformer_blocks:
        b.to(dtype=torch.float)

    for n, m in transformer.transformer_blocks.named_parameters():
        if "scale_shift_table" in n:
            m.data = m.data.to(torch.bfloat16)

    for b in transformer.transformer_blocks:
        b.forward = types.MethodType(patched_ltx_transformer_forward, b)

    transformer.forward = make_dynamic_graphed_callable(transformer.forward)
    return transformer


def warmup_transformer(pipe):
    prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
    for _ in range(5):
        _ = pipe(
            **prompt_embeds,
            output_type="latent",
            width=768,
            height=512,
            num_frames=121,
        )


def prepare_pipeline():
    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
    pipe.transformer = load_transformer()
    pipe = pipe.to("cuda")
    pipe.transformer.compile()
    pipe.set_progress_bar_config(disable=True)

    warmup_transformer(pipe)
    return pipe


def compute_hash(text: str) -> str:
    # Encode the text to bytes
    text_bytes = text.encode("utf-8")

    # Create a SHA-256 hash object
    hash_object = hashlib.sha256()

    # Update the hash object with the text bytes
    hash_object.update(text_bytes)

    # Return the hexadecimal representation of the hash
    return hash_object.hexdigest()
assets/mountain.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:540d193b2819996741ae23cf0bf106e1c3b226880b44ae1e5e54a8f38a181e9f
size 1463129
assets/river.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:56f21a785ac95e20e0933e553a1f5f82946737f4077ff4d71bf11de5b0ba5162
size 1788643
assets/woman.mp4
ADDED
Binary file (342 kB).
conversion_utils.py
ADDED
@@ -0,0 +1,103 @@
"""
References:
https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/utils/convert_weights.py
"""

import torch
from q8_ltx import replace_gelu, replace_linear, replace_rms_norm, MODULES_TO_NOT_CONVERT
import argparse
from diffusers import LTXVideoTransformer3DModel
from q8_kernels.functional.quantizer import quantize
from q8_kernels.functional.fast_hadamard import hadamard_transform


def convert_state_dict(orig_state_dict):
    prefix = "transformer_blocks"
    transformer_block_keys = []
    non_transformer_block_keys = []
    for k in orig_state_dict:
        if prefix in k:
            transformer_block_keys.append(k)
        else:
            non_transformer_block_keys.append(k)
    attn_keys = []
    ffn_keys = []
    scale_shift_keys = []
    for k in transformer_block_keys:
        if "attn" in k:
            attn_keys.append(k)
    for k in transformer_block_keys:
        if "ff" in k:
            ffn_keys.append(k)
    for k in transformer_block_keys:
        if "scale_shift_table" in k:
            scale_shift_keys.append(k)

    assert len(attn_keys + ffn_keys + scale_shift_keys) == len(transformer_block_keys), "error"

    new_state_dict = {}
    for k in attn_keys:
        new_key = k
        if "norm" in k and "weight" in k:
            new_state_dict[new_key] = orig_state_dict[k].float()
        elif "bias" in k:
            new_state_dict[new_key] = orig_state_dict[k].float()
        elif "weight" in k:
            w_quant, w_scales = quantize(hadamard_transform(orig_state_dict[k].cuda().to(torch.bfloat16)))
            assert w_quant.dtype == torch.int8, k
            new_state_dict[new_key] = w_quant
            new_state_dict[new_key.replace("weight", "scales")] = w_scales

    for k in ffn_keys:
        new_key = k

        if "bias" in k:
            new_state_dict[new_key] = orig_state_dict[k].float()
        elif "weight" in k:
            w_quant, w_scales = quantize(hadamard_transform(orig_state_dict[k].cuda().to(torch.bfloat16)))
            assert w_quant.dtype == torch.int8, k
            new_state_dict[new_key] = w_quant
            new_state_dict[new_key.replace("weight", "scales")] = w_scales

    for k in scale_shift_keys:
        new_state_dict[k] = orig_state_dict[k]

    for k in non_transformer_block_keys:
        new_state_dict[k] = orig_state_dict[k]

    return new_state_dict


@torch.no_grad()
def main(args):
    transformer = LTXVideoTransformer3DModel.from_pretrained(args.input_path, subfolder="transformer").to("cuda")
    new_state_dict = convert_state_dict(transformer.state_dict())
    transformer = replace_gelu(transformer)[0]
    transformer = replace_linear(transformer)[0]
    transformer = replace_rms_norm(transformer)[0]

    m, u = transformer.load_state_dict(new_state_dict, strict=True)
    for name, module in transformer.named_modules():
        if any(n in name for n in MODULES_TO_NOT_CONVERT):
            if hasattr(module, "weight"):
                assert module.weight.dtype == torch.float32
            elif hasattr(module, "linear"):
                assert module.linear.weight.dtype == torch.float32
        elif getattr(module, "weight", None) is not None:
            print(f"Non FP32 {name=} {module.weight.dtype=}")
            if "to_" in name:
                assert module.weight.dtype != torch.float32, f"{name=}, {module.weight.dtype=}"

    transformer.save_pretrained(args.output_path)
    print(f"Model saved in {args.output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)

    args = parser.parse_args()

    main(args)
inference.py
ADDED
@@ -0,0 +1,139 @@
from diffusers import LTXPipeline, LTXVideoTransformer3DModel
from huggingface_hub import hf_hub_download
import argparse
import os
from q8_ltx import check_transformer_replaced_correctly, replace_gelu, replace_linear, replace_rms_norm
import safetensors.torch
from q8_kernels.graph.graph import make_dynamic_graphed_callable
import torch
import gc
from diffusers.utils import export_to_video


# Taken from
# https://github.com/KONAKONA666/LTX-Video/blob/c8462ed2e359cda4dec7f49d98029994e850dc90/inference.py#L115C1-L138C28
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Remove non-letters and convert to lowercase
    clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
    # Split into words
    words = clean_text.split()

    # Build result string keeping track of length
    result = []
    current_length = 0

    for word in words:
        # Add word length plus 1 for underscore (except for first word)
        new_length = current_length + len(word)
        if new_length <= max_len:
            result.append(word)
            current_length += len(word)
        else:
            break

    return "-".join(result)


def load_text_encoding_pipeline():
    return LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video", transformer=None, vae=None, torch_dtype=torch.bfloat16
    ).to("cuda")


def encode_prompt(pipe, prompt, negative_prompt, max_sequence_length=128):
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
        prompt=prompt, negative_prompt=negative_prompt, max_sequence_length=max_sequence_length
    )
    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask


def load_q8_transformer(args):
    with torch.device("meta"):
        transformer_config = LTXVideoTransformer3DModel.load_config("Lightricks/LTX-Video", subfolder="transformer")
        transformer = LTXVideoTransformer3DModel.from_config(transformer_config)

    transformer = replace_gelu(transformer)[0]
    transformer = replace_linear(transformer)[0]
    transformer = replace_rms_norm(transformer)[0]

    if os.path.isfile(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors"):
        state_dict = safetensors.torch.load_file(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors")
    else:
        state_dict = safetensors.torch.load_file(
            hf_hub_download(args.q8_transformer_path, "diffusion_pytorch_model.safetensors")
        )
    transformer.load_state_dict(state_dict, strict=True, assign=True)
    check_transformer_replaced_correctly(transformer)
    return transformer


@torch.no_grad()
def main(args):
    text_encoding_pipeline = load_text_encoding_pipeline()
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = encode_prompt(
        pipe=text_encoding_pipeline,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_sequence_length,
    )
    del text_encoding_pipeline
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    if args.q8_transformer_path:
        transformer = load_q8_transformer(args)
        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", transformer=None, text_encoder=None)
        pipe.transformer = transformer

        pipe.transformer = pipe.transformer.to(torch.bfloat16)
        for b in pipe.transformer.transformer_blocks:
            b.to(dtype=torch.float)

        for n, m in pipe.transformer.transformer_blocks.named_parameters():
            if "scale_shift_table" in n:
                m.data = m.data.to(torch.bfloat16)

        pipe.transformer.forward = make_dynamic_graphed_callable(pipe.transformer.forward)
        pipe.vae = pipe.vae.to(torch.bfloat16)

    else:
        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)

    pipe = pipe.to("cuda")

    width, height = args.resolution.split("x")[::-1]
    video = pipe(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        width=int(width),
        height=int(height),
        num_frames=args.num_frames,
        num_inference_steps=args.steps,
        max_sequence_length=args.max_sequence_length,
        generator=torch.manual_seed(2025),
    ).frames[0]
    print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024 / 1024} MB.")

    if args.out_path is None:
        filename_from_prompt = convert_prompt_to_filename(args.prompt, max_len=30)
        base_filename = f"{filename_from_prompt}_{args.num_frames}x{height}x{width}"
        base_filename += "_q8" if args.q8_transformer_path is not None else ""
        args.out_path = base_filename + ".mp4"
    export_to_video(video, args.out_path, fps=24)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--q8_transformer_path", type=str, default=None)
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--negative_prompt", type=str, default=None)
    parser.add_argument("--num_frames", type=int, default=81)
    parser.add_argument("--resolution", type=str, default="480x704")
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--max_sequence_length", type=int, default=512)
    parser.add_argument("--out_path", type=str, default=None)
    args = parser.parse_args()
    main(args)
pipeline_ltx.py
ADDED
@@ -0,0 +1,811 @@
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
To account for certain changes pertaining to Q8Linear.
"""

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput

try:
    import q8_kernels  # noqa
    from q8_kernels.modules.linear import Q8Linear
except ImportError:
    Q8Linear = None


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import LTXPipeline
        >>> from diffusers.utils import export_to_video

        >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

        >>> video = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     width=704,
        ...     height=480,
        ...     num_frames=161,
        ...     num_inference_steps=50,
        ... ).frames[0]
        >>> export_to_video(video, "output.mp4", fps=24)
        ```
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation.

    Reference: https://github.com/Lightricks/LTX-Video

    Args:
        transformer ([`LTXVideoTransformer3DModel`]):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLLTXVideo`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKLLTXVideo,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: LTXVideoTransformer3DModel,
|
188 |
+
):
|
189 |
+
super().__init__()
|
190 |
+
|
191 |
+
self.register_modules(
|
192 |
+
vae=vae,
|
193 |
+
text_encoder=text_encoder,
|
194 |
+
tokenizer=tokenizer,
|
195 |
+
transformer=transformer,
|
196 |
+
scheduler=scheduler,
|
197 |
+
)
|
198 |
+
|
199 |
+
self.vae_spatial_compression_ratio = (
|
200 |
+
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
201 |
+
)
|
202 |
+
self.vae_temporal_compression_ratio = (
|
203 |
+
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
204 |
+
)
|
205 |
+
self.transformer_spatial_patch_size = (
|
206 |
+
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
|
207 |
+
)
|
208 |
+
self.transformer_temporal_patch_size = (
|
209 |
+
self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
|
210 |
+
)
|
211 |
+
|
212 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
213 |
+
self.tokenizer_max_length = (
|
214 |
+
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
|
215 |
+
)
|
216 |
+
|
217 |
+
def _get_t5_prompt_embeds(
|
218 |
+
self,
|
219 |
+
prompt: Union[str, List[str]] = None,
|
220 |
+
num_videos_per_prompt: int = 1,
|
221 |
+
max_sequence_length: int = 128,
|
222 |
+
device: Optional[torch.device] = None,
|
223 |
+
dtype: Optional[torch.dtype] = None,
|
224 |
+
):
|
225 |
+
device = device or self._execution_device
|
226 |
+
dtype = dtype or self.text_encoder.dtype
|
227 |
+
|
228 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
229 |
+
batch_size = len(prompt)
|
230 |
+
|
231 |
+
text_inputs = self.tokenizer(
|
232 |
+
prompt,
|
233 |
+
padding="max_length",
|
234 |
+
max_length=max_sequence_length,
|
235 |
+
truncation=True,
|
236 |
+
add_special_tokens=True,
|
237 |
+
return_tensors="pt",
|
238 |
+
)
|
239 |
+
text_input_ids = text_inputs.input_ids
|
240 |
+
prompt_attention_mask = text_inputs.attention_mask
|
241 |
+
prompt_attention_mask = prompt_attention_mask.bool().to(device)
|
242 |
+
|
243 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
244 |
+
|
245 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
246 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
247 |
+
logger.warning(
|
248 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
249 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
250 |
+
)
|
251 |
+
|
252 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
253 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
254 |
+
|
255 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
256 |
+
_, seq_len, _ = prompt_embeds.shape
|
257 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
258 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
259 |
+
|
260 |
+
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
261 |
+
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
262 |
+
|
263 |
+
return prompt_embeds, prompt_attention_mask
|
264 |
+
|
265 |
+
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
|
266 |
+
def encode_prompt(
|
267 |
+
self,
|
268 |
+
prompt: Union[str, List[str]],
|
269 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
270 |
+
do_classifier_free_guidance: bool = True,
|
271 |
+
num_videos_per_prompt: int = 1,
|
272 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
273 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
274 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
275 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
276 |
+
max_sequence_length: int = 128,
|
277 |
+
device: Optional[torch.device] = None,
|
278 |
+
dtype: Optional[torch.dtype] = None,
|
279 |
+
):
|
280 |
+
r"""
|
281 |
+
Encodes the prompt into text encoder hidden states.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
prompt (`str` or `List[str]`, *optional*):
|
285 |
+
prompt to be encoded
|
286 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
287 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
288 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
289 |
+
less than `1`).
|
290 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
291 |
+
Whether to use classifier free guidance or not.
|
292 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
293 |
+
Number of videos that should be generated per prompt.
|
294 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
295 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
296 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
297 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
298 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
299 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
300 |
+
argument.
|
301 |
+
device: (`torch.device`, *optional*):
|
302 |
+
torch device
|
303 |
+
dtype: (`torch.dtype`, *optional*):
|
304 |
+
torch dtype
|
305 |
+
"""
|
306 |
+
print(f"{max_sequence_length=}")
|
307 |
+
device = device or self._execution_device
|
308 |
+
|
309 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
310 |
+
if prompt is not None:
|
311 |
+
batch_size = len(prompt)
|
312 |
+
else:
|
313 |
+
batch_size = prompt_embeds.shape[0]
|
314 |
+
|
315 |
+
if prompt_embeds is None:
|
316 |
+
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
|
317 |
+
prompt=prompt,
|
318 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
319 |
+
max_sequence_length=max_sequence_length,
|
320 |
+
device=device,
|
321 |
+
dtype=dtype,
|
322 |
+
)
|
323 |
+
|
324 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
325 |
+
negative_prompt = negative_prompt or ""
|
326 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
327 |
+
|
328 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
329 |
+
raise TypeError(
|
330 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
331 |
+
f" {type(prompt)}."
|
332 |
+
)
|
333 |
+
elif batch_size != len(negative_prompt):
|
334 |
+
raise ValueError(
|
335 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
336 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
337 |
+
" the batch size of `prompt`."
|
338 |
+
)
|
339 |
+
|
340 |
+
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
|
341 |
+
prompt=negative_prompt,
|
342 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
343 |
+
max_sequence_length=max_sequence_length,
|
344 |
+
device=device,
|
345 |
+
dtype=dtype,
|
346 |
+
)
|
347 |
+
|
348 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
349 |
+
|
350 |
+
def check_inputs(
|
351 |
+
self,
|
352 |
+
prompt,
|
353 |
+
height,
|
354 |
+
width,
|
355 |
+
callback_on_step_end_tensor_inputs=None,
|
356 |
+
prompt_embeds=None,
|
357 |
+
negative_prompt_embeds=None,
|
358 |
+
prompt_attention_mask=None,
|
359 |
+
negative_prompt_attention_mask=None,
|
360 |
+
):
|
361 |
+
if height % 32 != 0 or width % 32 != 0:
|
362 |
+
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
363 |
+
|
364 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
365 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
366 |
+
):
|
367 |
+
raise ValueError(
|
368 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
369 |
+
)
|
370 |
+
|
371 |
+
if prompt is not None and prompt_embeds is not None:
|
372 |
+
raise ValueError(
|
373 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
374 |
+
" only forward one of the two."
|
375 |
+
)
|
376 |
+
elif prompt is None and prompt_embeds is None:
|
377 |
+
raise ValueError(
|
378 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
379 |
+
)
|
380 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
381 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
382 |
+
|
383 |
+
if prompt_embeds is not None and prompt_attention_mask is None:
|
384 |
+
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
385 |
+
|
386 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
387 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
388 |
+
|
389 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
390 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
391 |
+
raise ValueError(
|
392 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
393 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
394 |
+
f" {negative_prompt_embeds.shape}."
|
395 |
+
)
|
396 |
+
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
397 |
+
raise ValueError(
|
398 |
+
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
399 |
+
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
400 |
+
f" {negative_prompt_attention_mask.shape}."
|
401 |
+
)
|
402 |
+
|
403 |
+
@staticmethod
|
404 |
+
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
405 |
+
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
406 |
+
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
|
407 |
+
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
|
408 |
+
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
|
409 |
+
batch_size, num_channels, num_frames, height, width = latents.shape
|
410 |
+
post_patch_num_frames = num_frames // patch_size_t
|
411 |
+
post_patch_height = height // patch_size
|
412 |
+
post_patch_width = width // patch_size
|
413 |
+
latents = latents.reshape(
|
414 |
+
batch_size,
|
415 |
+
-1,
|
416 |
+
post_patch_num_frames,
|
417 |
+
patch_size_t,
|
418 |
+
post_patch_height,
|
419 |
+
patch_size,
|
420 |
+
post_patch_width,
|
421 |
+
patch_size,
|
422 |
+
)
|
423 |
+
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
424 |
+
return latents
|
425 |
+
|
426 |
+
@staticmethod
|
427 |
+
def _unpack_latents(
|
428 |
+
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
429 |
+
) -> torch.Tensor:
|
430 |
+
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
431 |
+
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
432 |
+
# what happens in the `_pack_latents` method.
|
433 |
+
batch_size = latents.size(0)
|
434 |
+
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
435 |
+
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
436 |
+
return latents
|
437 |
+
|
438 |
+
@staticmethod
|
439 |
+
def _normalize_latents(
|
440 |
+
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
441 |
+
) -> torch.Tensor:
|
442 |
+
# Normalize latents across the channel dimension [B, C, F, H, W]
|
443 |
+
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
444 |
+
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
445 |
+
latents = (latents - latents_mean) * scaling_factor / latents_std
|
446 |
+
return latents
|
447 |
+
|
448 |
+
@staticmethod
|
449 |
+
def _denormalize_latents(
|
450 |
+
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
451 |
+
) -> torch.Tensor:
|
452 |
+
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
453 |
+
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
454 |
+
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
455 |
+
latents = latents * latents_std / scaling_factor + latents_mean
|
456 |
+
return latents
|
457 |
+
|
458 |
+
def prepare_latents(
|
459 |
+
self,
|
460 |
+
batch_size: int = 1,
|
461 |
+
num_channels_latents: int = 128,
|
462 |
+
height: int = 512,
|
463 |
+
width: int = 704,
|
464 |
+
num_frames: int = 161,
|
465 |
+
dtype: Optional[torch.dtype] = None,
|
466 |
+
device: Optional[torch.device] = None,
|
467 |
+
generator: Optional[torch.Generator] = None,
|
468 |
+
latents: Optional[torch.Tensor] = None,
|
469 |
+
) -> torch.Tensor:
|
470 |
+
if latents is not None:
|
471 |
+
return latents.to(device=device, dtype=dtype)
|
472 |
+
|
473 |
+
height = height // self.vae_spatial_compression_ratio
|
474 |
+
width = width // self.vae_spatial_compression_ratio
|
475 |
+
num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
476 |
+
|
477 |
+
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
478 |
+
|
479 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
480 |
+
raise ValueError(
|
481 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
482 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
483 |
+
)
|
484 |
+
|
485 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
486 |
+
latents = self._pack_latents(
|
487 |
+
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
488 |
+
)
|
489 |
+
return latents
|
490 |
+
|
491 |
+
@property
|
492 |
+
def guidance_scale(self):
|
493 |
+
return self._guidance_scale
|
494 |
+
|
495 |
+
@property
|
496 |
+
def do_classifier_free_guidance(self):
|
497 |
+
return self._guidance_scale > 1.0
|
498 |
+
|
499 |
+
@property
|
500 |
+
def num_timesteps(self):
|
501 |
+
return self._num_timesteps
|
502 |
+
|
503 |
+
@property
|
504 |
+
def attention_kwargs(self):
|
505 |
+
return self._attention_kwargs
|
506 |
+
|
507 |
+
@property
|
508 |
+
def interrupt(self):
|
509 |
+
return self._interrupt
|
510 |
+
|
511 |
+
@torch.no_grad()
|
512 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
513 |
+
def __call__(
|
514 |
+
self,
|
515 |
+
prompt: Union[str, List[str]] = None,
|
516 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
517 |
+
height: int = 512,
|
518 |
+
width: int = 704,
|
519 |
+
num_frames: int = 161,
|
520 |
+
frame_rate: int = 25,
|
521 |
+
num_inference_steps: int = 50,
|
522 |
+
timesteps: List[int] = None,
|
523 |
+
guidance_scale: float = 3,
|
524 |
+
num_videos_per_prompt: Optional[int] = 1,
|
525 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
526 |
+
latents: Optional[torch.Tensor] = None,
|
527 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
528 |
+
prompt_attention_mask: Optional[torch.Tensor] = None,
|
529 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
530 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
531 |
+
decode_timestep: Union[float, List[float]] = 0.0,
|
532 |
+
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
533 |
+
output_type: Optional[str] = "pil",
|
534 |
+
return_dict: bool = True,
|
535 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
536 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
537 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
538 |
+
max_sequence_length: int = 128,
|
539 |
+
):
|
540 |
+
r"""
|
541 |
+
Function invoked when calling the pipeline for generation.
|
542 |
+
|
543 |
+
Args:
|
544 |
+
prompt (`str` or `List[str]`, *optional*):
|
545 |
+
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
|
546 |
+
instead.
|
547 |
+
height (`int`, defaults to `512`):
|
548 |
+
The height in pixels of the generated video. This defaults to `512`.
|
549 |
+
width (`int`, defaults to `704`):
|
550 |
+
The width in pixels of the generated video. This defaults to `704`.
|
551 |
+
num_frames (`int`, defaults to `161`):
|
552 |
+
The number of video frames to generate
|
553 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
554 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
555 |
+
expense of slower inference.
|
556 |
+
timesteps (`List[int]`, *optional*):
|
557 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
558 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
559 |
+
passed will be used. Must be in descending order.
|
560 |
+
guidance_scale (`float`, defaults to `3 `):
|
561 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
562 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
563 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
564 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
565 |
+
usually at the expense of lower image quality.
|
566 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
567 |
+
The number of videos to generate per prompt.
|
568 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
569 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
570 |
+
to make generation deterministic.
|
571 |
+
latents (`torch.Tensor`, *optional*):
|
572 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
573 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
574 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
575 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
576 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
577 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
578 |
+
prompt_attention_mask (`torch.Tensor`, *optional*):
|
579 |
+
Pre-generated attention mask for text embeddings.
|
580 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
581 |
+
Pre-generated negative text embeddings. If not
|
582 |
+
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
583 |
+
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
584 |
+
Pre-generated attention mask for negative text embeddings.
|
585 |
+
decode_timestep (`float`, defaults to `0.0`):
|
586 |
+
The timestep at which generated video is decoded.
|
587 |
+
decode_noise_scale (`float`, defaults to `None`):
|
588 |
+
The interpolation factor between random noise and denoised latents at the decode timestep.
|
589 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
590 |
+
The output format of the generated video. Choose between
|
591 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
592 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
593 |
+
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
594 |
+
attention_kwargs (`dict`, *optional*):
|
595 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
596 |
+
`self.processor` in
|
597 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
598 |
+
callback_on_step_end (`Callable`, *optional*):
|
599 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
600 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
601 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
602 |
+
`callback_on_step_end_tensor_inputs`.
|
603 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
604 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
605 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
606 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
607 |
+
max_sequence_length (`int`, defaults to `128`):
|
608 |
+
Maximum sequence length to use with the `prompt`.
|
609 |
+
|
610 |
+
Examples:
|
611 |
+
|
612 |
+
Returns:
|
613 |
+
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
614 |
+
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
615 |
+
returned where the first element is a list with the generated images.
|
616 |
+
"""
|
617 |
+
|
618 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
619 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
620 |
+
|
621 |
+
# 1. Check inputs. Raise error if not correct
|
622 |
+
self.check_inputs(
|
623 |
+
prompt=prompt,
|
624 |
+
height=height,
|
625 |
+
width=width,
|
626 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
627 |
+
prompt_embeds=prompt_embeds,
|
628 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
629 |
+
prompt_attention_mask=prompt_attention_mask,
|
630 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
631 |
+
)
|
632 |
+
|
633 |
+
self._guidance_scale = guidance_scale
|
634 |
+
self._attention_kwargs = attention_kwargs
|
635 |
+
self._interrupt = False
|
636 |
+
|
637 |
+
# 2. Define call parameters
|
638 |
+
if prompt is not None and isinstance(prompt, str):
|
639 |
+
batch_size = 1
|
640 |
+
elif prompt is not None and isinstance(prompt, list):
|
641 |
+
batch_size = len(prompt)
|
642 |
+
else:
|
643 |
+
batch_size = prompt_embeds.shape[0]
|
644 |
+
|
645 |
+
device = self._execution_device
|
646 |
+
|
647 |
+
# 3. Prepare text embeddings
|
648 |
+
(
|
649 |
+
prompt_embeds,
|
650 |
+
prompt_attention_mask,
|
651 |
+
negative_prompt_embeds,
|
652 |
+
negative_prompt_attention_mask,
|
653 |
+
) = self.encode_prompt(
|
654 |
+
prompt=prompt,
|
655 |
+
negative_prompt=negative_prompt,
|
656 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
657 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
658 |
+
prompt_embeds=prompt_embeds,
|
659 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
660 |
+
prompt_attention_mask=prompt_attention_mask,
|
661 |
+
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
662 |
+
max_sequence_length=max_sequence_length,
|
663 |
+
device=device,
|
664 |
+
)
|
665 |
+
if self.do_classifier_free_guidance:
|
666 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
667 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
668 |
+
|
669 |
+
if Q8Linear is not None and isinstance(self.transformer.transformer_blocks[0].attn1.to_q, Q8Linear):
|
670 |
+
prompt_attention_mask = prompt_attention_mask.to(torch.int64)
|
671 |
+
prompt_attention_mask = prompt_attention_mask.argmin(-1).int().squeeze()
|
672 |
+
prompt_attention_mask[prompt_attention_mask == 0] = max_sequence_length
|
673 |
+
|
674 |
+
# 4. Prepare latent variables
|
675 |
+
num_channels_latents = self.transformer.config.in_channels
|
676 |
+
latents = self.prepare_latents(
|
677 |
+
batch_size * num_videos_per_prompt,
|
678 |
+
num_channels_latents,
|
679 |
+
height,
|
680 |
+
width,
|
681 |
+
num_frames,
|
682 |
+
torch.float32,
|
683 |
+
device,
|
684 |
+
generator,
|
685 |
+
latents,
|
686 |
+
)
|
687 |
+
|
688 |
+
# 5. Prepare timesteps
|
689 |
+
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
690 |
+
latent_height = height // self.vae_spatial_compression_ratio
|
691 |
+
latent_width = width // self.vae_spatial_compression_ratio
|
692 |
+
video_sequence_length = latent_num_frames * latent_height * latent_width
|
693 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
694 |
+
mu = calculate_shift(
|
695 |
+
video_sequence_length,
|
696 |
+
self.scheduler.config.base_image_seq_len,
|
697 |
+
self.scheduler.config.max_image_seq_len,
|
698 |
+
self.scheduler.config.base_shift,
|
699 |
+
self.scheduler.config.max_shift,
|
700 |
+
)
|
701 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
702 |
+
self.scheduler,
|
703 |
+
num_inference_steps,
|
704 |
+
device,
|
705 |
+
timesteps,
|
706 |
+
sigmas=sigmas,
|
707 |
+
mu=mu,
|
708 |
+
)
|
709 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
710 |
+
self._num_timesteps = len(timesteps)
|
711 |
+
|
712 |
+
# 6. Prepare micro-conditions
|
713 |
+
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
|
714 |
+
rope_interpolation_scale = (
|
715 |
+
1 / latent_frame_rate,
|
716 |
+
self.vae_spatial_compression_ratio,
|
717 |
+
self.vae_spatial_compression_ratio,
|
718 |
+
)
|
719 |
+
|
720 |
+
# 7. Denoising loop
|
721 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
722 |
+
for i, t in enumerate(timesteps):
|
723 |
+
if self.interrupt:
|
724 |
+
continue
|
725 |
+
|
726 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
727 |
+
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
728 |
+
|
729 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
730 |
+
timestep = t.expand(latent_model_input.shape[0])
|
731 |
+
|
732 |
+
noise_pred = self.transformer(
|
733 |
+
hidden_states=latent_model_input,
|
734 |
+
encoder_hidden_states=prompt_embeds,
|
735 |
+
timestep=timestep,
|
736 |
+
encoder_attention_mask=prompt_attention_mask,
|
737 |
+
num_frames=latent_num_frames,
|
738 |
+
height=latent_height,
|
739 |
+
width=latent_width,
|
740 |
+
rope_interpolation_scale=rope_interpolation_scale,
|
741 |
+
attention_kwargs=attention_kwargs,
|
742 |
+
return_dict=False,
|
743 |
+
)[0]
|
744 |
+
noise_pred = noise_pred.float()
|
745 |
+
|
746 |
+
if self.do_classifier_free_guidance:
|
747 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
748 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
749 |
+
|
750 |
+
# compute the previous noisy sample x_t -> x_t-1
|
751 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
752 |
+
|
753 |
+
if callback_on_step_end is not None:
|
754 |
+
callback_kwargs = {}
|
755 |
+
for k in callback_on_step_end_tensor_inputs:
|
756 |
+
callback_kwargs[k] = locals()[k]
|
757 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
758 |
+
|
759 |
+
latents = callback_outputs.pop("latents", latents)
|
760 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
761 |
+
|
762 |
+
# call the callback, if provided
|
763 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
764 |
+
progress_bar.update()
|
765 |
+
|
766 |
+
if XLA_AVAILABLE:
|
767 |
+
xm.mark_step()
|
768 |
+
|
769 |
+
if output_type == "latent":
|
770 |
+
video = latents
|
771 |
+
else:
|
772 |
+
latents = self._unpack_latents(
|
773 |
+
latents,
|
774 |
+
latent_num_frames,
|
775 |
+
latent_height,
|
776 |
+
latent_width,
|
777 |
+
self.transformer_spatial_patch_size,
|
778 |
+
self.transformer_temporal_patch_size,
|
779 |
+
)
|
780 |
+
latents = self._denormalize_latents(
|
781 |
+
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
782 |
+
)
|
783 |
+
latents = latents.to(prompt_embeds.dtype)
|
784 |
+
|
785 |
+
if not self.vae.config.timestep_conditioning:
|
786 |
+
timestep = None
|
787 |
+
else:
|
788 |
+
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
789 |
+
if not isinstance(decode_timestep, list):
|
790 |
+
decode_timestep = [decode_timestep] * batch_size
|
791 |
+
if decode_noise_scale is None:
|
792 |
+
decode_noise_scale = decode_timestep
|
793 |
+
elif not isinstance(decode_noise_scale, list):
|
794 |
+
decode_noise_scale = [decode_noise_scale] * batch_size
|
795 |
+
|
796 |
+
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
797 |
+
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
798 |
+
:, None, None, None, None
|
799 |
+
]
|
800 |
+
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
801 |
+
|
802 |
+
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
803 |
+
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
804 |
+
|
805 |
+
# Offload all models
|
806 |
+
self.maybe_free_model_hooks()
|
807 |
+
|
808 |
+
if not return_dict:
|
809 |
+
return (video,)
|
810 |
+
|
811 |
+
return LTXPipelineOutput(frames=video)
|
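The `prompt_embeds.pt` file added below appears to hold precomputed T5 text embeddings. A minimal sketch of how such a file could drive the patched pipeline without running the text encoder forward pass; the dictionary keys, and the assumption that the patched class keeps the upstream `from_pretrained` behaviour, are assumptions rather than something this repo guarantees:

```py
# Hedged sketch: run LTXPipeline from precomputed T5 embeddings.
# The keys inside prompt_embeds.pt are assumptions; adjust to the actual file.
import torch
from pipeline_ltx import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

embeds = torch.load("prompt_embeds.pt", map_location="cuda")  # assumed: dict of tensors
video = pipe(
    prompt_embeds=embeds["prompt_embeds"],                                    # assumed key
    prompt_attention_mask=embeds["prompt_attention_mask"],                    # assumed key
    negative_prompt_embeds=embeds["negative_prompt_embeds"],                  # assumed key
    negative_prompt_attention_mask=embeds["negative_prompt_attention_mask"],  # assumed key
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
```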
prompt_embeds.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2095fe12d573adf5adf6e2f0d163d494f792fe83deaeb6b20e67a28e6e44e913
|
3 |
+
size 8391808
|
q8_attention_processors.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
"""
|
2 |
+
Reference:
|
3 |
+
https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/modules/attention.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import q8_kernels.functional as Q8F
|
8 |
+
from diffusers.models.transformers.transformer_ltx import apply_rotary_emb
|
9 |
+
from diffusers.models.attention import Attention
|
10 |
+
|
11 |
+
NON_MM_PRECISION_TYPE = torch.bfloat16
|
12 |
+
MM_PRECISION_TYPE = torch.bfloat16
|
13 |
+
|
14 |
+
|
15 |
+
class LTXVideoQ8AttentionProcessor:
|
16 |
+
def __call__(
|
17 |
+
self,
|
18 |
+
attn: Attention,
|
19 |
+
hidden_states: torch.Tensor,
|
20 |
+
encoder_hidden_states=None,
|
21 |
+
attention_mask=None,
|
22 |
+
image_rotary_emb=None,
|
23 |
+
) -> torch.Tensor:
|
24 |
+
if attention_mask is not None and attention_mask.ndim > 1:
|
25 |
+
attention_mask = attention_mask.argmin(-1).squeeze().int()
|
26 |
+
|
27 |
+
if encoder_hidden_states is None:
|
28 |
+
encoder_hidden_states = hidden_states
|
29 |
+
|
30 |
+
query = attn.to_q(hidden_states)
|
31 |
+
key = attn.to_k(encoder_hidden_states)
|
32 |
+
value = attn.to_v(encoder_hidden_states)
|
33 |
+
|
34 |
+
query = attn.norm_q(query, NON_MM_PRECISION_TYPE)
|
35 |
+
key = attn.norm_k(key, NON_MM_PRECISION_TYPE)
|
36 |
+
|
37 |
+
if image_rotary_emb is not None:
|
38 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
39 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
40 |
+
|
41 |
+
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
42 |
+
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
43 |
+
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
44 |
+
|
45 |
+
hidden_states = Q8F.flash_attention.flash_attn_func(
|
46 |
+
query, key, value, batch_mask=attention_mask, apply_qk_hadamard=True
|
47 |
+
)
|
48 |
+
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
49 |
+
|
50 |
+
hidden_states = attn.to_out[0](hidden_states)
|
51 |
+
hidden_states = attn.to_out[1](hidden_states)
|
52 |
+
return hidden_states.to(NON_MM_PRECISION_TYPE)
|
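A minimal usage sketch for this processor, assuming the transformer's linear layers and RMS norms have already been swapped for their Q8 counterparts (the two-argument `norm_q`/`norm_k` calls above only exist on the Q8 RMSNorm); `check_transformer_replaced_correctly` in `q8_ltx.py` below performs the same installation automatically:

```py
# Hedged sketch: install the Q8 attention processor on an already-converted
# LTX transformer (see q8_ltx.py for the module conversion helpers).
from q8_attention_processors import LTXVideoQ8AttentionProcessor


def use_q8_attention(transformer) -> None:
    # Both self-attention (attn1) and cross-attention (attn2) get the processor.
    for block in transformer.transformer_blocks:
        block.attn1.set_processor(LTXVideoQ8AttentionProcessor())
        block.attn2.set_processor(LTXVideoQ8AttentionProcessor())
```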
q8_ltx.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
"""
|
2 |
+
References:
|
3 |
+
https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/utils/convert_weights.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from q8_kernels.modules.rms_norm import RMSNorm as QRMSNorm
|
9 |
+
from diffusers.models.normalization import RMSNorm
|
10 |
+
from q8_kernels.modules.activations import GELU as QGELU
|
11 |
+
from diffusers.models.activations import GELU
|
12 |
+
from q8_kernels.modules.linear import Q8Linear
|
13 |
+
from q8_attention_processors import LTXVideoQ8AttentionProcessor
|
14 |
+
|
15 |
+
MODULES_TO_NOT_CONVERT = ["proj_in", "time_embed", "caption_projection", "proj_out"]
|
16 |
+
|
17 |
+
|
18 |
+
def replace_linear(model, current_key_name=None, replaced=False):
|
19 |
+
for name, child in model.named_children():
|
20 |
+
if current_key_name is None:
|
21 |
+
current_key_name = []
|
22 |
+
current_key_name.append(name)
|
23 |
+
|
24 |
+
if isinstance(child, nn.Linear) and name not in MODULES_TO_NOT_CONVERT:
|
25 |
+
# Check if the current key is not in the `modules_to_not_convert`
|
26 |
+
current_key_name_str = ".".join(current_key_name)
|
27 |
+
if not any(
|
28 |
+
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in MODULES_TO_NOT_CONVERT
|
29 |
+
):
|
30 |
+
new_linear = Q8Linear(
|
31 |
+
child.in_features, child.out_features, bias=child.bias is not None, device=child.weight.device
|
32 |
+
)
|
33 |
+
setattr(model, name, new_linear)
|
34 |
+
replaced = True
|
35 |
+
else:
|
36 |
+
replace_linear(model=child, current_key_name=current_key_name, replaced=replaced)
|
37 |
+
|
38 |
+
current_key_name.pop(-1)
|
39 |
+
|
40 |
+
return model, replaced
|
41 |
+
|
42 |
+
|
43 |
+
def get_parent_module_and_attr(root, dotted_name: str):
|
44 |
+
"""
|
45 |
+
Splits 'a.b.c' into:
|
46 |
+
- parent module = root.a.b
|
47 |
+
- attr_name = 'c'
|
48 |
+
"""
|
49 |
+
parts = dotted_name.split(".")
|
50 |
+
*parent_parts, attr_name = parts
|
51 |
+
parent_module = root
|
52 |
+
for p in parent_parts:
|
53 |
+
parent_module = getattr(parent_module, p)
|
54 |
+
return parent_module, attr_name
|
55 |
+
|
56 |
+
|
57 |
+
def replace_rms_norm(model):
|
58 |
+
modules_to_replace = []
|
59 |
+
for dotted_name, module in model.named_modules():
|
60 |
+
if isinstance(module, RMSNorm):
|
61 |
+
modules_to_replace.append((dotted_name, module))
|
62 |
+
|
63 |
+
replaced = False
|
64 |
+
for dotted_name, module in modules_to_replace:
|
65 |
+
parent, attr_name = get_parent_module_and_attr(model, dotted_name)
|
66 |
+
new_norm = QRMSNorm(
|
67 |
+
dim=module.dim,
|
68 |
+
elementwise_affine=module.elementwise_affine,
|
69 |
+
)
|
70 |
+
setattr(parent, attr_name, new_norm)
|
71 |
+
replaced = True
|
72 |
+
|
73 |
+
return model, replaced
|
74 |
+
|
75 |
+
|
76 |
+
def replace_gelu(model, replaced=False):
|
77 |
+
for name, child in model.named_children():
|
78 |
+
if isinstance(child, GELU):
|
79 |
+
new_gelu = QGELU(
|
80 |
+
dim_in=child.proj.in_features,
|
81 |
+
dim_out=child.proj.out_features,
|
82 |
+
approximate=child.approximate,
|
83 |
+
bias=child.proj.bias is not None,
|
84 |
+
)
|
85 |
+
setattr(model, name, new_gelu)
|
86 |
+
replaced = True
|
87 |
+
else:
|
88 |
+
replace_gelu(model=child, replaced=replaced)
|
89 |
+
|
90 |
+
return model, replaced
|
91 |
+
|
92 |
+
|
93 |
+
def set_attn_processors(model, processor):
|
94 |
+
def fn_recursive_attn_processor(name, module: torch.nn.Module, processor):
|
95 |
+
if hasattr(module, "set_processor"):
|
96 |
+
module.set_processor(processor)
|
97 |
+
for sub_name, child in module.named_children():
|
98 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
99 |
+
|
100 |
+
for name, module in model.named_children():
|
101 |
+
fn_recursive_attn_processor(name, module, processor)
|
102 |
+
|
103 |
+
|
104 |
+
def attn_processors(model) -> dict:
|
105 |
+
processors = {}
|
106 |
+
|
107 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict):
|
108 |
+
if hasattr(module, "get_processor"):
|
109 |
+
processors[f"{name}.processor"] = module.get_processor()
|
110 |
+
|
111 |
+
for sub_name, child in module.named_children():
|
112 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
113 |
+
|
114 |
+
return processors
|
115 |
+
|
116 |
+
for name, module in model.named_children():
|
117 |
+
fn_recursive_add_processors(name, module, processors)
|
118 |
+
|
119 |
+
return processors
|
120 |
+
|
121 |
+
|
122 |
+
def check_transformer_replaced_correctly(model):
|
123 |
+
for block in model.transformer_blocks:
|
124 |
+
assert isinstance(block.attn1.to_q, Q8Linear), f"{type(block.attn1.to_q)=} not linear."
|
125 |
+
assert isinstance(block.attn2.to_q, Q8Linear), f"{type(block.attn2.to_q)=} not linear."
|
126 |
+
assert block.attn1.to_q.weight.dtype == torch.int8, f"{block.attn1.to_q.weight.dtype=}."
|
127 |
+
assert block.attn2.to_q.weight.dtype == torch.int8, f"{block.attn2.to_q.weight.dtype=}."
|
128 |
+
|
129 |
+
for name, module in model.named_modules():
|
130 |
+
if "norm" in name and "norm_out" not in name:
|
131 |
+
assert isinstance(module, QRMSNorm), f"{name=}, {type(module)=}"
|
132 |
+
|
133 |
+
for block in model.transformer_blocks:
|
134 |
+
assert isinstance(block.ff.net[0], QGELU), f"{type(block.ff.net[0])=}"
|
135 |
+
if getattr(block.ff.net[0], "proj", None) is not None:
|
136 |
+
assert block.ff.net[0].proj.weight.dtype == torch.int8, f"{block.ff.net[0].proj.weight.dtype=}."
|
137 |
+
|
138 |
+
set_attn_processors(model, LTXVideoQ8AttentionProcessor())
|
139 |
+
all_attn_processors = attn_processors(model)
|
140 |
+
for k, v in all_attn_processors.items():
|
141 |
+
assert isinstance(v, LTXVideoQ8AttentionProcessor), f"{k} is not of type LTXVideoQ8AttentionProcessor."
|
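Taken together, a hedged sketch of how these helpers could be applied to an LTX transformer before int8 weights are loaded; the checkpoint path is hypothetical, and the actual weight quantization lives in the q8_kernels `convert_weights.py` referenced at the top of this file:

```py
# Hedged sketch: swap nn.Linear / GELU / RMSNorm for their Q8 counterparts,
# load a pre-quantized state dict, then verify the swap and install the Q8
# attention processors.
import torch
from diffusers import LTXVideoTransformer3DModel

from q8_ltx import (
    check_transformer_replaced_correctly,
    replace_gelu,
    replace_linear,
    replace_rms_norm,
)

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer, _ = replace_linear(transformer)
transformer, _ = replace_gelu(transformer)
transformer, _ = replace_rms_norm(transformer)

# Hypothetical path to an int8 state dict produced by separate conversion utilities.
state_dict = torch.load("ltx_transformer_q8.pt", map_location="cpu")
transformer.load_state_dict(state_dict, strict=False)  # strict=False hedges Q8 buffer-name differences

check_transformer_replaced_correctly(transformer)
```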
requirements.txt
ADDED
@@ -0,0 +1,7 @@
1 |
+
torch
|
2 |
+
diffusers
|
3 |
+
transformers
|
4 |
+
accelerate
|
5 |
+
imageio
|
6 |
+
Pillow
|
7 |
+
git+https://github.com/KONAKONA666/q8_kernels.git
|