Commit 8cf98bd (verified) · 1 parent: 81d9422
sayakpaul committed: Upload 14 files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/mountain.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/river.mp4 filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,132 @@
- ---
- title: Q8 Ltx Video
- emoji: 👀
- colorFrom: red
- colorTo: indigo
- sdk: gradio
- sdk_version: 5.12.0
- app_file: app.py
- pinned: false
- short_description: Generate videos with LTX-Video fast!
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Q8 LTX-Video optimized for Ada
+
+ This repository shows how to use the Q8 kernels from [`KONAKONA666/q8_kernels`](https://github.com/KONAKONA666/q8_kernels) with `diffusers` to optimize inference of [LTX-Video](https://huggingface.co/Lightricks/LTX-Video) on Ada GPUs. Inference time goes from 16.192 secs down to 9.572 secs while memory drops from ~7 GB to ~5 GB, without quality loss 🤪 With `torch.compile()`, the time drops further to 6.747 secs 🔥
+
+ The Q8 transformer checkpoint is available here: [`sayakpaul/q8-ltx-video`](https://hf.co/sayakpaul/q8-ltx-video).
+
+ ## Getting started
+
+ Install the dependencies:
+
+ ```bash
+ pip install -U transformers accelerate
+ git clone https://github.com/huggingface/diffusers && cd diffusers && pip install -e . && cd ..
+ ```
+
+ Then install `q8_kernels`, following the instructions from [here](https://github.com/KONAKONA666/q8_kernels/?tab=readme-ov-file#installation).
+
+ Running inference with the Q8 kernels needs some minor changes to `diffusers`. Apply [this patch](https://github.com/sayakpaul/q8-ltx-video/blob/368f549ca5136daf89049c9efe32748e73aca317/updates.patch) to take those into account:
+
+ ```bash
+ git apply updates.patch
+ ```
+
+ Now we can run inference:
+
+ ```bash
+ python inference.py \
+ --prompt="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" \
+ --negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted" \
+ --q8_transformer_path="sayakpaul/q8-ltx-video"
+ ```
+
+ ## Why does this repo exist, and how does it work?
+
+ There already exists [`KONAKONA666/LTX-Video`](https://github.com/KONAKONA666/LTX-Video). So why this repo?
+
+ That repo uses custom implementations of the LTX-Video pipeline components, which makes it hard to use directly with `diffusers`. This repo instead applies the kernels from `q8_kernels` to the components that come directly from `diffusers`.
+
+ <details>
+ <summary>More details</summary>
+
+ We do this by first converting the state dict of the original [LTX-Video transformer](https://huggingface.co/Lightricks/LTX-Video/tree/main/transformer). This includes 8-bit quantization of the weights. The process also requires replacing:
+
+ * the linear layers of the model
+ * the RMSNorms of the model
+ * the GELUs of the model
+
+ before the converted state dict is loaded into the model. Some layer parameters are kept in FP32, and some layers are not quantized at all. The replacement utilities are in [`q8_ltx.py`](./q8_ltx.py).
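+
+ To make the scheme concrete, here is a minimal, illustrative sketch of such a module swap. The `from_linear` constructor shown here is an assumption made for illustration; the real replacement logic lives in [`q8_ltx.py`](./q8_ltx.py) and uses the actual `q8_kernels` modules.
+
+ ```python
+ import torch.nn as nn
+
+ def replace_linear_sketch(model: nn.Module, q8_linear_cls, skip_names=()):
+     # Recursively swap nn.Linear modules for a Q8 linear implementation.
+     # Modules whose names match `skip_names` are left untouched (these are
+     # the layers that stay in FP32 / unquantized).
+     for name, child in model.named_children():
+         if any(skip in name for skip in skip_names):
+             continue
+         if isinstance(child, nn.Linear):
+             # `from_linear` is a hypothetical helper; adapt it to whatever
+             # constructor the real `q8_kernels` linear module exposes.
+             setattr(model, name, q8_linear_cls.from_linear(child))
+         else:
+             replace_linear_sketch(child, q8_linear_cls, skip_names)
+     return model
+ ```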
+
+ The model can then be serialized. The conversion and serialization steps are implemented in [`conversion_utils.py`](./conversion_utils.py).
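+
+ For reference, the conversion script takes the source repository and an output directory; the output path below is just an example:
+
+ ```bash
+ python conversion_utils.py \
+   --input_path="Lightricks/LTX-Video" \
+   --output_path="./q8-ltx-video-transformer"
+ ```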
+
+ When loading the model and using it for inference, we:
+
+ * initialize the transformer model under the "meta" device
+ * follow the same layer-replacement scheme detailed above
+ * load the converted state dict into the replaced modules
+ * replace the attention processors with ones that use [the flash-attention implementation](https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/functional/flash_attention.py) from `q8_kernels`
+
+ Refer [here](https://github.com/sayakpaul/q8-ltx-video/blob/368f549ca5136daf89049c9efe32748e73aca317/inference.py#L48) for more details. The [attention processors](https://github.com/sayakpaul/q8-ltx-video/blob/368f549ca5136daf89049c9efe32748e73aca317/q8_attention_processors.py#L44) that wrap the `q8_kernels` flash attention provide a further speedup.
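+
+ Abridged from [`inference.py`](./inference.py), the loading scheme looks roughly like this:
+
+ ```python
+ import safetensors.torch
+ import torch
+ from diffusers import LTXVideoTransformer3DModel
+ from huggingface_hub import hf_hub_download
+ from q8_ltx import replace_gelu, replace_linear, replace_rms_norm
+
+ # Build the architecture on the "meta" device (no memory is allocated yet).
+ with torch.device("meta"):
+     config = LTXVideoTransformer3DModel.load_config("Lightricks/LTX-Video", subfolder="transformer")
+     transformer = LTXVideoTransformer3DModel.from_config(config)
+
+ # Swap in the Q8 modules before loading any weights.
+ transformer = replace_gelu(transformer)[0]
+ transformer = replace_linear(transformer)[0]
+ transformer = replace_rms_norm(transformer)[0]
+
+ # Materialize the pre-quantized weights directly into the replaced modules.
+ state_dict = safetensors.torch.load_file(
+     hf_hub_download("sayakpaul/q8-ltx-video", "diffusion_pytorch_model.safetensors")
+ )
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
+ ```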
+
+ </details>
+
+ ## Performance
+
+ The numbers below were obtained with `max_sequence_length=512`, `num_inference_steps=50`, `num_frames=81`, and `resolution=480x704`. The rest of the arguments were kept at their default values, as per the [pipeline call signature of LTX-Video](https://github.com/huggingface/diffusers/blob/4b9f1c7d8c2e476eed38af3144b79105a5efcd93/src/diffusers/pipelines/ltx/pipeline_ltx.py#L496). The numbers don't include VAE decoding time, so they focus solely on the transformer.
+
+ | | **Time (secs)** | **Memory (MB)** |
+ |:-----------:|:-----------:|:-----------:|
+ | Non-Q8 | 16.192 | 7172.86 |
+ | Non-Q8 (+ compile) | 16.205 | - |
+ | Q8 | 9.572 | 5413.51 |
+ | Q8 (+ compile) | 6.747 | - |
+
+ The benchmarking script is available in [`benchmark.py`](./benchmark.py). You need to download the precomputed prompt embeddings from [here](https://huggingface.co/sayakpaul/q8-ltx-video/blob/main/prompt_embeds.pt) before running the benchmark.
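+
+ The authoritative numbers come from [`benchmark.py`](./benchmark.py); the snippet below is only a sketch of one way to time the transformer-only path (precomputed prompt embeddings, `output_type="latent"` to skip VAE decoding, CUDA events for timing):
+
+ ```python
+ import torch
+
+ def time_pipeline(pipe, pipe_kwargs, warmup=2, runs=5):
+     # Warm up so compilation / graph capture doesn't pollute the timing.
+     for _ in range(warmup):
+         pipe(**pipe_kwargs, output_type="latent")
+
+     start = torch.cuda.Event(enable_timing=True)
+     end = torch.cuda.Event(enable_timing=True)
+     torch.cuda.synchronize()
+     start.record()
+     for _ in range(runs):
+         pipe(**pipe_kwargs, output_type="latent")
+     end.record()
+     torch.cuda.synchronize()
+     return start.elapsed_time(end) / (1000 * runs)  # average seconds per call
+ ```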
+
+ <details>
+ <summary>Env</summary>
+
+ ```bash
+ +-----------------------------------------------------------------------------------------+
+ | NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
+ |-----------------------------------------+------------------------+----------------------+
+ | GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+ | Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+ |                                         |                        |               MIG M. |
+ |=========================================+========================+======================|
+ |   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
+ |  0%   46C    P8             18W /  450W |       2MiB /  24564MiB |      0%      Default |
+ |                                         |                        |                  N/A |
+ +-----------------------------------------+------------------------+----------------------+
+ ```
+
+ `diffusers-cli env`:
+
+ ```bash
+ - 🤗 Diffusers version: 0.33.0.dev0
+ - Platform: Linux-6.8.0-49-generic-x86_64-with-glibc2.39
+ - Running on Google Colab?: No
+ - Python version: 3.10.12
+ - PyTorch version (GPU?): 2.5.1+cu124 (True)
+ - Flax version (CPU?/GPU?/TPU?): not installed (NA)
+ - Jax version: not installed
+ - JaxLib version: not installed
+ - Huggingface_hub version: 0.27.0
+ - Transformers version: 4.47.1
+ - Accelerate version: 1.2.1
+ - PEFT version: 0.13.2
+ - Bitsandbytes version: 0.44.1
+ - Safetensors version: 0.4.4
+ - xFormers version: 0.0.29.post1
+ - Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
+   NVIDIA GeForce RTX 4090, 24564 MiB
+ - Using GPU in script?: <fill in>
+ - Using distributed or parallel set-up in script?: <fill in>
+ ```
+
+ </details>
+
+ > [!NOTE]
+ > The RoPE implementation from `q8_kernels` [isn't usable as of January 1, 2025](https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/functional/rope.py#L26), so we resort to [the one](https://github.com/huggingface/diffusers/blob/91008aabc4b8dbd96a356ab6f457f3bd84b10e8b/src/diffusers/models/transformers/transformer_ltx.py#L464) from `diffusers`.
+
+ ## Comparison
+
+ Check out [this page](https://wandb.ai/sayakpaul/q8-ltx-video/runs/89h6ac5) on Weights and Biases for some comparative results. Generated videos are also available [here](./videos/).
+
+ ## Acknowledgement
+
+ KONAKONA666's work on [`KONAKONA666/q8_kernels`](https://github.com/KONAKONA666/q8_kernels) and [`KONAKONA666/LTX-Video`](https://github.com/KONAKONA666/LTX-Video).
+
app.py ADDED
@@ -0,0 +1,156 @@
1
+ import gradio as gr
2
+ from app_utils import prepare_pipeline, compute_hash
3
+ import os
4
+ from diffusers.utils import export_to_video
5
+ import tempfile
6
+ import torch
7
+ from inference import load_text_encoding_pipeline
8
+
9
+
10
+ text_encoding_pipeline = load_text_encoding_pipeline()
11
+ inference_pipeline = prepare_pipeline()
12
+
13
+
14
+ def create_advanced_options():
15
+ with gr.Accordion("Advanced Options (Optional)", open=False):
16
+ seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=646373)
17
+ inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=30)
18
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=5.0, step=0.1, value=3.0)
19
+ max_sequence_length = gr.Slider(label="Maximum sequence length", minimum=128, maximum=512, step=1, value=128)
20
+ fps = gr.Slider(label="FPS", minimum=21, maximum=30, step=1, value=24)
21
+ return [
22
+ seed,
23
+ inference_steps,
24
+ guidance_scale,
25
+ max_sequence_length,
26
+ fps
27
+ ]
28
+
29
+
30
+ @torch.no_grad()
31
+ def generate_video_from_text(prompt, negative_prompt, seed, steps, guidance_scale, max_sequence_length, fps):
32
+ global text_encoding_pipeline
33
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = text_encoding_pipeline.encode_prompt(
34
+ prompt=prompt, negative_prompt=negative_prompt, max_sequence_length=max_sequence_length
35
+ )
36
+ global inference_pipeline
37
+ video = inference_pipeline(
38
+ prompt_embeds=prompt_embeds,
39
+ prompt_attention_mask=prompt_attention_mask,
40
+ negative_prompt_embeds=negative_prompt_embeds,
41
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
42
+ guidance_scale=guidance_scale,
43
+ width=768,
44
+ height=512,
45
+ num_frames=121,
46
+ num_inference_steps=steps,
47
+ max_sequence_length=max_sequence_length,
48
+ generator=torch.manual_seed(seed),
49
+ ).frames[0]
50
+ # tempfile.mkstemp() returns (fd, path); close the fd and pass only the path.
+ fd, out_path = tempfile.mkstemp(suffix=".mp4")
+ os.close(fd)
+ export_to_video(video, out_path, fps=fps)
+ return out_path
53
+
54
+
55
+ with gr.Blocks(theme=gr.themes.Soft()) as iface:
56
+ with gr.Row(elem_id="title-row"):
57
+ gr.Markdown(
58
+ """
59
+ <div style="text-align: center; margin-bottom: 1em">
60
+ <h1 style="font-size: 2.5em; font-weight: 600; margin: 0.5em 0;">Fast Video Generation with <a href="https://github.com/sayakpaul/q8-ltx-video">Q8 LTX Video</a></h1>
61
+ </div>
62
+ """
63
+ )
64
+ with gr.Row(elem_id="title-row"):
65
+ gr.HTML( # add technical report link
66
+ """
67
+ <div style="display:flex;column-gap:4px;">
68
+ <span>This space is modified from the original <a href="https://huggingface.co/spaces/Lightricks/LTX-Video-Playground">LTX-Video playground</a>. It uses optimized Q8 kernels along with torch.compile to allow for ultra-fast video
69
+ generation. As a result, it restricts generations to 121 frames at 512x768 resolution. For more details, refer to <a href="https://github.com/sayakpaul/q8-ltx-video">this link</a>.</span>
70
+ </div>
71
+ """
72
+ )
73
+ with gr.Accordion(" 📖 Tips for Best Results", open=False, elem_id="instructions-accordion"):
74
+ gr.Markdown(
75
+ """
76
+ 📝 Prompt Engineering
77
+ When writing prompts, focus on detailed, chronological descriptions of actions and scenes. Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph. Start directly with the action, and keep descriptions literal and precise. Think like a cinematographer describing a shot list. Keep within 200 words.
78
+ For best results, build your prompts using this structure:
79
+ - Start with main action in a single sentence
80
+ - Add specific details about movements and gestures
81
+ - Describe character/object appearances precisely
82
+ - Include background and environment details
83
+ - Specify camera angles and movements
84
+ - Describe lighting and colors
85
+ - Note any changes or sudden events
86
+ See examples for more inspiration.
87
+ 🎮 Parameter Guide
88
+ - Resolution Preset: Higher resolutions for detailed scenes, lower for faster generation and simpler scenes
89
+ - Seed: Save seed values to recreate specific styles or compositions you like
90
+ - Guidance Scale: 3-3.5 are the recommended values
91
+ - Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
92
+ - When using detailed prompts, use a higher `max_sequence_length` value.
93
+ """
94
+ )
95
+
96
+ with gr.Tabs():
97
+ # Text to Video Tab
98
+ with gr.TabItem("Text to Video"):
99
+ with gr.Row():
100
+ with gr.Column():
101
+ txt2vid_prompt = gr.Textbox(
102
+ label="Enter Your Prompt",
103
+ placeholder="Describe the video you want to generate (minimum 50 characters)...",
104
+ value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage.",
105
+ lines=5,
106
+ )
107
+
108
+ txt2vid_negative_prompt = gr.Textbox(
109
+ label="Enter Negative Prompt",
110
+ placeholder="Describe what you don't want in the video...",
111
+ value="low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
112
+ lines=2,
113
+ )
114
+ txt2vid_advanced = create_advanced_options()
115
+
116
+ txt2vid_generate = gr.Button(
117
+ "Generate Video",
118
+ variant="primary",
119
+ size="lg",
120
+ )
121
+
122
+ with gr.Column():
123
+ txt2vid_output = gr.Video(label="Generated Output")
124
+
125
+ with gr.Row():
126
+ gr.Examples(
127
+ examples=[
128
+ [
129
+ "A clear, turquoise river flows through a rocky canyon, cascading over a small waterfall and forming a pool of water at the bottom.The river is the main focus of the scene, with its clear water reflecting the surrounding trees and rocks. The canyon walls are steep and rocky, with some vegetation growing on them. The trees are mostly pine trees, with their green needles contrasting with the brown and gray rocks. The overall tone of the scene is one of peace and tranquility.",
130
+ "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
131
+ "assets/river.mp4",
132
+ ],
133
+ [
134
+ "The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys.The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature.",
135
+ "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly",
136
+ "assets/mountain.mp4",
137
+ ],
138
+ ],
139
+ inputs=[txt2vid_prompt, txt2vid_negative_prompt, txt2vid_output],
140
+ label="Example Text-to-Video Generations",
141
+ )
142
+
143
+ txt2vid_generate.click(
144
+ fn=generate_video_from_text,
145
+ inputs=[
146
+ txt2vid_prompt,
147
+ txt2vid_negative_prompt,
148
+ *txt2vid_advanced,
149
+ ],
150
+ outputs=txt2vid_output,
151
+ concurrency_limit=1,
152
+ concurrency_id="generate_video_from_text",
153
+ )
154
+
155
+ if __name__ == "__main__":
156
+ iface.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False)
app_utils.py ADDED
@@ -0,0 +1,99 @@
1
+ from inference import load_q8_transformer
2
+ import hashlib
3
+ from q8_kernels.graph.graph import make_dynamic_graphed_callable
4
+ from argparse import Namespace
5
+ from diffusers import LTXPipeline
6
+ import types
7
+ import torch
8
+
9
+ # To account for the type-casting in `ff_output` of `LTXVideoTransformerBlock`
10
+ def patched_ltx_transformer_forward(
11
+ self,
12
+ hidden_states: torch.Tensor,
13
+ encoder_hidden_states: torch.Tensor,
14
+ temb: torch.Tensor,
15
+ image_rotary_emb = None,
16
+ encoder_attention_mask = None,
17
+ ) -> torch.Tensor:
18
+ batch_size = hidden_states.size(0)
19
+ norm_hidden_states = self.norm1(hidden_states)
20
+
21
+ num_ada_params = self.scale_shift_table.shape[0]
22
+ ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
23
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
24
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
25
+
26
+ attn_hidden_states = self.attn1(
27
+ hidden_states=norm_hidden_states,
28
+ encoder_hidden_states=None,
29
+ image_rotary_emb=image_rotary_emb,
30
+ )
31
+ hidden_states = hidden_states + attn_hidden_states * gate_msa
32
+
33
+ attn_hidden_states = self.attn2(
34
+ hidden_states,
35
+ encoder_hidden_states=encoder_hidden_states,
36
+ image_rotary_emb=None,
37
+ attention_mask=encoder_attention_mask,
38
+ )
39
+ hidden_states = hidden_states + attn_hidden_states
40
+ norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
41
+
42
+ ff_output = self.ff(norm_hidden_states).to(norm_hidden_states.dtype)
43
+ hidden_states = hidden_states + ff_output * gate_mlp
44
+
45
+ return hidden_states
46
+
47
+ def load_transformer():
48
+ args = Namespace()
49
+ args.q8_transformer_path = "sayakpaul/q8-ltx-video"
50
+ transformer = load_q8_transformer(args)
51
+
52
+ transformer.to(torch.bfloat16)
53
+ for b in transformer.transformer_blocks:
54
+ b.to(dtype=torch.float)
55
+
56
+ for n, m in transformer.transformer_blocks.named_parameters():
57
+ if "scale_shift_table" in n:
58
+ m.data = m.data.to(torch.bfloat16)
59
+
60
+ for b in transformer.transformer_blocks:
61
+ b.forward = types.MethodType(patched_ltx_transformer_forward, b)
62
+
63
+ transformer.forward = make_dynamic_graphed_callable(transformer.forward)
64
+ return transformer
65
+
66
+ def warmup_transformer(pipe):
67
+ prompt_embeds = torch.load("prompt_embeds.pt", map_location="cuda", weights_only=True)
68
+ for _ in range(5):
69
+ _ = pipe(
70
+ **prompt_embeds,
71
+ output_type="latent",
72
+ width=768,
73
+ height=512,
74
+ num_frames=121
75
+ )
76
+
77
+ def prepare_pipeline():
78
+ pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
79
+ pipe.transformer = load_transformer()
80
+ pipe = pipe.to("cuda")
81
+ pipe.transformer.compile()
82
+ pipe.set_progress_bar_config(disable=True)
83
+
84
+ warmup_transformer(pipe)
85
+ return pipe
86
+
87
+
88
+ def compute_hash(text: str) -> str:
89
+ # Encode the text to bytes
90
+ text_bytes = text.encode("utf-8")
91
+
92
+ # Create a SHA-256 hash object
93
+ hash_object = hashlib.sha256()
94
+
95
+ # Update the hash object with the text bytes
96
+ hash_object.update(text_bytes)
97
+
98
+ # Return the hexadecimal representation of the hash
99
+ return hash_object.hexdigest()
assets/mountain.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:540d193b2819996741ae23cf0bf106e1c3b226880b44ae1e5e54a8f38a181e9f
3
+ size 1463129
assets/river.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56f21a785ac95e20e0933e553a1f5f82946737f4077ff4d71bf11de5b0ba5162
3
+ size 1788643
assets/woman.mp4 ADDED
Binary file (342 kB).
 
conversion_utils.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ References:
3
+ https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/utils/convert_weights.py
4
+ """
5
+
6
+ import torch
7
+ from q8_ltx import replace_gelu, replace_linear, replace_rms_norm, MODULES_TO_NOT_CONVERT
8
+ import argparse
9
+ from diffusers import LTXVideoTransformer3DModel
10
+ from q8_kernels.functional.quantizer import quantize
11
+ from q8_kernels.functional.fast_hadamard import hadamard_transform
12
+
13
+
14
+ def convert_state_dict(orig_state_dict):
15
+ prefix = "transformer_blocks"
16
+ transformer_block_keys = []
17
+ non_transformer_block_keys = []
18
+ for k in orig_state_dict:
19
+ if prefix in k:
20
+ transformer_block_keys.append(k)
21
+ else:
22
+ non_transformer_block_keys.append(k)
23
+ attn_keys = []
24
+ ffn_keys = []
25
+ scale_shift_keys = []
26
+ for k in transformer_block_keys:
27
+ if "attn" in k:
28
+ attn_keys.append(k)
29
+ for k in transformer_block_keys:
30
+ if "ff" in k:
31
+ ffn_keys.append(k)
32
+ for k in transformer_block_keys:
33
+ if "scale_shift_table" in k:
34
+ scale_shift_keys.append(k)
35
+
36
+ assert len(attn_keys + ffn_keys + scale_shift_keys) == len(transformer_block_keys), "error"
37
+
38
+ new_state_dict = {}
39
+ for k in attn_keys:
40
+ new_key = k
41
+ if "norm" in k and "weight" in k:
42
+ new_state_dict[new_key] = orig_state_dict[k].float()
43
+ elif "bias" in k:
44
+ new_state_dict[new_key] = orig_state_dict[k].float()
45
+ elif "weight" in k:
46
+ w_quant, w_scales = quantize(hadamard_transform(orig_state_dict[k].cuda().to(torch.bfloat16)))
47
+ assert w_quant.dtype == torch.int8, k
48
+ new_state_dict[new_key] = w_quant
49
+ new_state_dict[new_key.replace("weight", "scales")] = w_scales
50
+
51
+ for k in ffn_keys:
52
+ new_key = k
53
+
54
+ if "bias" in k:
55
+ new_state_dict[new_key] = orig_state_dict[k].float()
56
+ elif "weight" in k:
57
+ w_quant, w_scales = quantize(hadamard_transform(orig_state_dict[k].cuda().to(torch.bfloat16)))
58
+ assert w_quant.dtype == torch.int8, k
59
+ new_state_dict[new_key] = w_quant
60
+ new_state_dict[new_key.replace("weight", "scales")] = w_scales
61
+
62
+ for k in scale_shift_keys:
63
+ new_state_dict[k] = orig_state_dict[k]
64
+
65
+ for k in non_transformer_block_keys:
66
+ new_state_dict[k] = orig_state_dict[k]
67
+
68
+ return new_state_dict
69
+
70
+
71
+ @torch.no_grad()
72
+ def main(args):
73
+ transformer = LTXVideoTransformer3DModel.from_pretrained(args.input_path, subfolder="transformer").to("cuda")
74
+ new_state_dict = convert_state_dict(transformer.state_dict())
75
+ transformer = replace_gelu(transformer)[0]
76
+ transformer = replace_linear(transformer)[0]
77
+ transformer = replace_rms_norm(transformer)[0]
78
+
79
+ m, u = transformer.load_state_dict(new_state_dict, strict=True)
80
+ for name, module in transformer.named_modules():
81
+ if any(n in name for n in MODULES_TO_NOT_CONVERT):
82
+ if hasattr(module, "weight"):
83
+ assert module.weight.dtype == torch.float32
84
+ elif hasattr(module, "linear"):
85
+ assert module.linear.weight.dtype == torch.float32
86
+ elif getattr(module, "weight", None) is not None:
87
+ print(f"Non FP32 {name=} {module.weight.dtype=}")
88
+ if "to_" in name:
89
+ assert module.weight.dtype != torch.float32, f"{name=}, {module.weight.dtype=}"
90
+
91
+ transformer.save_pretrained(args.output_path)
92
+ print(f"Model saved in {args.output_path}")
93
+
94
+
95
+ if __name__ == "__main__":
96
+ parser = argparse.ArgumentParser()
97
+
98
+ parser.add_argument("--input_path", type=str, required=True)
99
+ parser.add_argument("--output_path", type=str, required=True)
100
+
101
+ args = parser.parse_args()
102
+
103
+ main(args)
inference.py ADDED
@@ -0,0 +1,139 @@
1
+ from diffusers import LTXPipeline, LTXVideoTransformer3DModel
2
+ from huggingface_hub import hf_hub_download
3
+ import argparse
4
+ import os
5
+ from q8_ltx import check_transformer_replaced_correctly, replace_gelu, replace_linear, replace_rms_norm
6
+ import safetensors.torch
7
+ from q8_kernels.graph.graph import make_dynamic_graphed_callable
8
+ import torch
9
+ import gc
10
+ from diffusers.utils import export_to_video
11
+
12
+
13
+ # Taken from
14
+ # https://github.com/KONAKONA666/LTX-Video/blob/c8462ed2e359cda4dec7f49d98029994e850dc90/inference.py#L115C1-L138C28
15
+ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
16
+ # Remove non-letters and convert to lowercase
17
+ clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
18
+ # Split into words
19
+ words = clean_text.split()
20
+
21
+ # Build result string keeping track of length
22
+ result = []
23
+ current_length = 0
24
+
25
+ for word in words:
26
+ # Add word length plus 1 for underscore (except for first word)
27
+ new_length = current_length + len(word)
28
+ if new_length <= max_len:
29
+ result.append(word)
30
+ current_length += len(word)
31
+ else:
32
+ break
33
+
34
+ return "-".join(result)
35
+
36
+
37
+ def load_text_encoding_pipeline():
38
+ return LTXPipeline.from_pretrained(
39
+ "Lightricks/LTX-Video", transformer=None, vae=None, torch_dtype=torch.bfloat16
40
+ ).to("cuda")
41
+
42
+
43
+ def encode_prompt(pipe, prompt, negative_prompt, max_sequence_length=128):
44
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
45
+ prompt=prompt, negative_prompt=negative_prompt, max_sequence_length=max_sequence_length
46
+ )
47
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
48
+
49
+
50
+ def load_q8_transformer(args):
51
+ with torch.device("meta"):
52
+ transformer_config = LTXVideoTransformer3DModel.load_config("Lightricks/LTX-Video", subfolder="transformer")
53
+ transformer = LTXVideoTransformer3DModel.from_config(transformer_config)
54
+
55
+ transformer = replace_gelu(transformer)[0]
56
+ transformer = replace_linear(transformer)[0]
57
+ transformer = replace_rms_norm(transformer)[0]
58
+
59
+ if os.path.isfile(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors"):
60
+ state_dict = safetensors.torch.load_file(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors")
61
+ else:
62
+ state_dict = safetensors.torch.load_file(
63
+ hf_hub_download(args.q8_transformer_path, "diffusion_pytorch_model.safetensors")
64
+ )
65
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
66
+ check_transformer_replaced_correctly(transformer)
67
+ return transformer
68
+
69
+
70
+ @torch.no_grad()
71
+ def main(args):
72
+ text_encoding_pipeline = load_text_encoding_pipeline()
73
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = encode_prompt(
74
+ pipe=text_encoding_pipeline,
75
+ prompt=args.prompt,
76
+ negative_prompt=args.negative_prompt,
77
+ max_sequence_length=args.max_sequence_length,
78
+ )
79
+ del text_encoding_pipeline
80
+ torch.cuda.empty_cache()
81
+ torch.cuda.reset_peak_memory_stats()
82
+ gc.collect()
83
+
84
+ if args.q8_transformer_path:
85
+ transformer = load_q8_transformer(args)
86
+ pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", transformer=None, text_encoder=None)
87
+ pipe.transformer = transformer
88
+
89
+ pipe.transformer = pipe.transformer.to(torch.bfloat16)
90
+ for b in pipe.transformer.transformer_blocks:
91
+ b.to(dtype=torch.float)
92
+
93
+ for n, m in pipe.transformer.transformer_blocks.named_parameters():
94
+ if "scale_shift_table" in n:
95
+ m.data = m.data.to(torch.bfloat16)
96
+
97
+ pipe.transformer.forward = make_dynamic_graphed_callable(pipe.transformer.forward)
98
+ pipe.vae = pipe.vae.to(torch.bfloat16)
99
+
100
+ else:
101
+ pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
102
+
103
+ pipe = pipe.to("cuda")
104
+
105
+ width, height = args.resolution.split("x")[::-1]
106
+ video = pipe(
107
+ prompt_embeds=prompt_embeds,
108
+ prompt_attention_mask=prompt_attention_mask,
109
+ negative_prompt_embeds=negative_prompt_embeds,
110
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
111
+ width=int(width),
112
+ height=int(height),
113
+ num_frames=args.num_frames,
114
+ num_inference_steps=args.steps,
115
+ max_sequence_length=args.max_sequence_length,
116
+ generator=torch.manual_seed(2025),
117
+ ).frames[0]
118
+ print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024 / 1024} MB.")
119
+
120
+ if args.out_path is None:
121
+ filename_from_prompt = convert_prompt_to_filename(args.prompt, max_len=30)
122
+ base_filename = f"{filename_from_prompt}_{args.num_frames}x{height}x{width}"
123
+ base_filename += "_q8" if args.q8_transformer_path is not None else ""
124
+ args.out_path = base_filename + ".mp4"
125
+ export_to_video(video, args.out_path, fps=24)
126
+
127
+
128
+ if __name__ == "__main__":
129
+ parser = argparse.ArgumentParser()
130
+ parser.add_argument("--q8_transformer_path", type=str, default=None)
131
+ parser.add_argument("--prompt", type=str)
132
+ parser.add_argument("--negative_prompt", type=str, default=None)
133
+ parser.add_argument("--num_frames", type=int, default=81)
134
+ parser.add_argument("--resolution", type=str, default="480x704")
135
+ parser.add_argument("--steps", type=int, default=50)
136
+ parser.add_argument("--max_sequence_length", type=int, default=512)
137
+ parser.add_argument("--out_path", type=str, default=None)
138
+ args = parser.parse_args()
139
+ main(args)
pipeline_ltx.py ADDED
@@ -0,0 +1,811 @@
1
+ # Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ To account for certain changes pertaining to Q8Linear.
17
+ """
18
+
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ from transformers import T5EncoderModel, T5TokenizerFast
25
+
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
28
+ from diffusers.models.autoencoders import AutoencoderKLLTXVideo
29
+ from diffusers.models.transformers import LTXVideoTransformer3DModel
30
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
31
+ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
32
+ from diffusers.utils.torch_utils import randn_tensor
33
+ from diffusers.video_processor import VideoProcessor
34
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
35
+ from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
36
+
37
+ try:
38
+ import q8_kernels # noqa
39
+ from q8_kernels.modules.linear import Q8Linear
40
+ except:
41
+ Q8Linear = None
42
+
43
+
44
+ if is_torch_xla_available():
45
+ import torch_xla.core.xla_model as xm
46
+
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """
54
+ Examples:
55
+ ```py
56
+ >>> import torch
57
+ >>> from diffusers import LTXPipeline
58
+ >>> from diffusers.utils import export_to_video
59
+
60
+ >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
61
+ >>> pipe.to("cuda")
62
+
63
+ >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
64
+ >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
65
+
66
+ >>> video = pipe(
67
+ ... prompt=prompt,
68
+ ... negative_prompt=negative_prompt,
69
+ ... width=704,
70
+ ... height=480,
71
+ ... num_frames=161,
72
+ ... num_inference_steps=50,
73
+ ... ).frames[0]
74
+ >>> export_to_video(video, "output.mp4", fps=24)
75
+ ```
76
+ """
77
+
78
+
79
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
80
+ def calculate_shift(
81
+ image_seq_len,
82
+ base_seq_len: int = 256,
83
+ max_seq_len: int = 4096,
84
+ base_shift: float = 0.5,
85
+ max_shift: float = 1.16,
86
+ ):
87
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
88
+ b = base_shift - m * base_seq_len
89
+ mu = image_seq_len * m + b
90
+ return mu
91
+
92
+
93
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
94
+ def retrieve_timesteps(
95
+ scheduler,
96
+ num_inference_steps: Optional[int] = None,
97
+ device: Optional[Union[str, torch.device]] = None,
98
+ timesteps: Optional[List[int]] = None,
99
+ sigmas: Optional[List[float]] = None,
100
+ **kwargs,
101
+ ):
102
+ r"""
103
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
104
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
105
+
106
+ Args:
107
+ scheduler (`SchedulerMixin`):
108
+ The scheduler to get timesteps from.
109
+ num_inference_steps (`int`):
110
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
111
+ must be `None`.
112
+ device (`str` or `torch.device`, *optional*):
113
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
114
+ timesteps (`List[int]`, *optional*):
115
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
116
+ `num_inference_steps` and `sigmas` must be `None`.
117
+ sigmas (`List[float]`, *optional*):
118
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
119
+ `num_inference_steps` and `timesteps` must be `None`.
120
+
121
+ Returns:
122
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
123
+ second element is the number of inference steps.
124
+ """
125
+ if timesteps is not None and sigmas is not None:
126
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
127
+ if timesteps is not None:
128
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
129
+ if not accepts_timesteps:
130
+ raise ValueError(
131
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
132
+ f" timestep schedules. Please check whether you are using the correct scheduler."
133
+ )
134
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
135
+ timesteps = scheduler.timesteps
136
+ num_inference_steps = len(timesteps)
137
+ elif sigmas is not None:
138
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
139
+ if not accept_sigmas:
140
+ raise ValueError(
141
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
142
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
143
+ )
144
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
145
+ timesteps = scheduler.timesteps
146
+ num_inference_steps = len(timesteps)
147
+ else:
148
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
149
+ timesteps = scheduler.timesteps
150
+ return timesteps, num_inference_steps
151
+
152
+
153
+ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
154
+ r"""
155
+ Pipeline for text-to-video generation.
156
+
157
+ Reference: https://github.com/Lightricks/LTX-Video
158
+
159
+ Args:
160
+ transformer ([`LTXVideoTransformer3DModel`]):
161
+ Conditional Transformer architecture to denoise the encoded video latents.
162
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
163
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
164
+ vae ([`AutoencoderKLLTXVideo`]):
165
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
166
+ text_encoder ([`T5EncoderModel`]):
167
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
168
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
169
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
175
+ """
176
+
177
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
178
+ _optional_components = []
179
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
180
+
181
+ def __init__(
182
+ self,
183
+ scheduler: FlowMatchEulerDiscreteScheduler,
184
+ vae: AutoencoderKLLTXVideo,
185
+ text_encoder: T5EncoderModel,
186
+ tokenizer: T5TokenizerFast,
187
+ transformer: LTXVideoTransformer3DModel,
188
+ ):
189
+ super().__init__()
190
+
191
+ self.register_modules(
192
+ vae=vae,
193
+ text_encoder=text_encoder,
194
+ tokenizer=tokenizer,
195
+ transformer=transformer,
196
+ scheduler=scheduler,
197
+ )
198
+
199
+ self.vae_spatial_compression_ratio = (
200
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
201
+ )
202
+ self.vae_temporal_compression_ratio = (
203
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
204
+ )
205
+ self.transformer_spatial_patch_size = (
206
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
207
+ )
208
+ self.transformer_temporal_patch_size = (
209
+ self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
210
+ )
211
+
212
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
213
+ self.tokenizer_max_length = (
214
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
215
+ )
216
+
217
+ def _get_t5_prompt_embeds(
218
+ self,
219
+ prompt: Union[str, List[str]] = None,
220
+ num_videos_per_prompt: int = 1,
221
+ max_sequence_length: int = 128,
222
+ device: Optional[torch.device] = None,
223
+ dtype: Optional[torch.dtype] = None,
224
+ ):
225
+ device = device or self._execution_device
226
+ dtype = dtype or self.text_encoder.dtype
227
+
228
+ prompt = [prompt] if isinstance(prompt, str) else prompt
229
+ batch_size = len(prompt)
230
+
231
+ text_inputs = self.tokenizer(
232
+ prompt,
233
+ padding="max_length",
234
+ max_length=max_sequence_length,
235
+ truncation=True,
236
+ add_special_tokens=True,
237
+ return_tensors="pt",
238
+ )
239
+ text_input_ids = text_inputs.input_ids
240
+ prompt_attention_mask = text_inputs.attention_mask
241
+ prompt_attention_mask = prompt_attention_mask.bool().to(device)
242
+
243
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
244
+
245
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
246
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
247
+ logger.warning(
248
+ "The following part of your input was truncated because `max_sequence_length` is set to "
249
+ f" {max_sequence_length} tokens: {removed_text}"
250
+ )
251
+
252
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
253
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
254
+
255
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
256
+ _, seq_len, _ = prompt_embeds.shape
257
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
258
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
259
+
260
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
261
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
262
+
263
+ return prompt_embeds, prompt_attention_mask
264
+
265
+ # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
266
+ def encode_prompt(
267
+ self,
268
+ prompt: Union[str, List[str]],
269
+ negative_prompt: Optional[Union[str, List[str]]] = None,
270
+ do_classifier_free_guidance: bool = True,
271
+ num_videos_per_prompt: int = 1,
272
+ prompt_embeds: Optional[torch.Tensor] = None,
273
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
274
+ prompt_attention_mask: Optional[torch.Tensor] = None,
275
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
276
+ max_sequence_length: int = 128,
277
+ device: Optional[torch.device] = None,
278
+ dtype: Optional[torch.dtype] = None,
279
+ ):
280
+ r"""
281
+ Encodes the prompt into text encoder hidden states.
282
+
283
+ Args:
284
+ prompt (`str` or `List[str]`, *optional*):
285
+ prompt to be encoded
286
+ negative_prompt (`str` or `List[str]`, *optional*):
287
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
288
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
289
+ less than `1`).
290
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
291
+ Whether to use classifier free guidance or not.
292
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
293
+ Number of videos that should be generated per prompt.
294
+ prompt_embeds (`torch.Tensor`, *optional*):
295
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
296
+ provided, text embeddings will be generated from `prompt` input argument.
297
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
298
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
299
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
300
+ argument.
301
+ device: (`torch.device`, *optional*):
302
+ torch device
303
+ dtype: (`torch.dtype`, *optional*):
304
+ torch dtype
305
+ """
306
+
307
+ device = device or self._execution_device
308
+
309
+ prompt = [prompt] if isinstance(prompt, str) else prompt
310
+ if prompt is not None:
311
+ batch_size = len(prompt)
312
+ else:
313
+ batch_size = prompt_embeds.shape[0]
314
+
315
+ if prompt_embeds is None:
316
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
317
+ prompt=prompt,
318
+ num_videos_per_prompt=num_videos_per_prompt,
319
+ max_sequence_length=max_sequence_length,
320
+ device=device,
321
+ dtype=dtype,
322
+ )
323
+
324
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
325
+ negative_prompt = negative_prompt or ""
326
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
327
+
328
+ if prompt is not None and type(prompt) is not type(negative_prompt):
329
+ raise TypeError(
330
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
331
+ f" {type(prompt)}."
332
+ )
333
+ elif batch_size != len(negative_prompt):
334
+ raise ValueError(
335
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
336
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
337
+ " the batch size of `prompt`."
338
+ )
339
+
340
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
341
+ prompt=negative_prompt,
342
+ num_videos_per_prompt=num_videos_per_prompt,
343
+ max_sequence_length=max_sequence_length,
344
+ device=device,
345
+ dtype=dtype,
346
+ )
347
+
348
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
349
+
350
+ def check_inputs(
351
+ self,
352
+ prompt,
353
+ height,
354
+ width,
355
+ callback_on_step_end_tensor_inputs=None,
356
+ prompt_embeds=None,
357
+ negative_prompt_embeds=None,
358
+ prompt_attention_mask=None,
359
+ negative_prompt_attention_mask=None,
360
+ ):
361
+ if height % 32 != 0 or width % 32 != 0:
362
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
363
+
364
+ if callback_on_step_end_tensor_inputs is not None and not all(
365
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
366
+ ):
367
+ raise ValueError(
368
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
369
+ )
370
+
371
+ if prompt is not None and prompt_embeds is not None:
372
+ raise ValueError(
373
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
374
+ " only forward one of the two."
375
+ )
376
+ elif prompt is None and prompt_embeds is None:
377
+ raise ValueError(
378
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
379
+ )
380
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
381
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
382
+
383
+ if prompt_embeds is not None and prompt_attention_mask is None:
384
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
385
+
386
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
387
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
388
+
389
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
390
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
391
+ raise ValueError(
392
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
393
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
394
+ f" {negative_prompt_embeds.shape}."
395
+ )
396
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
397
+ raise ValueError(
398
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
399
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
400
+ f" {negative_prompt_attention_mask.shape}."
401
+ )
402
+
403
+ @staticmethod
404
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
405
+ # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
406
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
407
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
408
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
409
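+ # Worked example (purely illustrative, using the pipeline defaults): with patch_size=1 and
+ # patch_size_t=1, latents of shape [1, 128, 21, 16, 22] are packed into [1, 21*16*22, 128] = [1, 7392, 128].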
+ batch_size, num_channels, num_frames, height, width = latents.shape
410
+ post_patch_num_frames = num_frames // patch_size_t
411
+ post_patch_height = height // patch_size
412
+ post_patch_width = width // patch_size
413
+ latents = latents.reshape(
414
+ batch_size,
415
+ -1,
416
+ post_patch_num_frames,
417
+ patch_size_t,
418
+ post_patch_height,
419
+ patch_size,
420
+ post_patch_width,
421
+ patch_size,
422
+ )
423
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
424
+ return latents
425
+
426
+ @staticmethod
427
+ def _unpack_latents(
428
+ latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
429
+ ) -> torch.Tensor:
430
+ # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
431
+ # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
432
+ # what happens in the `_pack_latents` method.
433
+ batch_size = latents.size(0)
434
+ latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
435
+ latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
436
+ return latents
437
+
438
+ @staticmethod
439
+ def _normalize_latents(
440
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
441
+ ) -> torch.Tensor:
442
+ # Normalize latents across the channel dimension [B, C, F, H, W]
443
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
444
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
445
+ latents = (latents - latents_mean) * scaling_factor / latents_std
446
+ return latents
447
+
448
+ @staticmethod
449
+ def _denormalize_latents(
450
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
451
+ ) -> torch.Tensor:
452
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
453
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
454
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
455
+ latents = latents * latents_std / scaling_factor + latents_mean
456
+ return latents
457
+
458
+ def prepare_latents(
459
+ self,
460
+ batch_size: int = 1,
461
+ num_channels_latents: int = 128,
462
+ height: int = 512,
463
+ width: int = 704,
464
+ num_frames: int = 161,
465
+ dtype: Optional[torch.dtype] = None,
466
+ device: Optional[torch.device] = None,
467
+ generator: Optional[torch.Generator] = None,
468
+ latents: Optional[torch.Tensor] = None,
469
+ ) -> torch.Tensor:
470
+ if latents is not None:
471
+ return latents.to(device=device, dtype=dtype)
472
+
473
+ height = height // self.vae_spatial_compression_ratio
474
+ width = width // self.vae_spatial_compression_ratio
475
+ num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
476
+
477
+ shape = (batch_size, num_channels_latents, num_frames, height, width)
478
+
479
+ if isinstance(generator, list) and len(generator) != batch_size:
480
+ raise ValueError(
481
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
482
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
483
+ )
484
+
485
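+ # Noise is sampled in the unpacked [B, C, F, H, W] latent space and then packed into the
+ # [B, S, D] token layout expected by the transformer.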
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
486
+ latents = self._pack_latents(
487
+ latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
488
+ )
489
+ return latents
490
+
491
+ @property
492
+ def guidance_scale(self):
493
+ return self._guidance_scale
494
+
495
+ @property
496
+ def do_classifier_free_guidance(self):
497
+ return self._guidance_scale > 1.0
498
+
499
+ @property
500
+ def num_timesteps(self):
501
+ return self._num_timesteps
502
+
503
+ @property
504
+ def attention_kwargs(self):
505
+ return self._attention_kwargs
506
+
507
+ @property
508
+ def interrupt(self):
509
+ return self._interrupt
510
+
511
+ @torch.no_grad()
512
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
513
+ def __call__(
514
+ self,
515
+ prompt: Union[str, List[str]] = None,
516
+ negative_prompt: Optional[Union[str, List[str]]] = None,
517
+ height: int = 512,
518
+ width: int = 704,
519
+ num_frames: int = 161,
520
+ frame_rate: int = 25,
521
+ num_inference_steps: int = 50,
522
+ timesteps: List[int] = None,
523
+ guidance_scale: float = 3,
524
+ num_videos_per_prompt: Optional[int] = 1,
525
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
526
+ latents: Optional[torch.Tensor] = None,
527
+ prompt_embeds: Optional[torch.Tensor] = None,
528
+ prompt_attention_mask: Optional[torch.Tensor] = None,
529
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
530
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
531
+ decode_timestep: Union[float, List[float]] = 0.0,
532
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
533
+ output_type: Optional[str] = "pil",
534
+ return_dict: bool = True,
535
+ attention_kwargs: Optional[Dict[str, Any]] = None,
536
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
537
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
538
+ max_sequence_length: int = 128,
539
+ ):
540
+ r"""
541
+ Function invoked when calling the pipeline for generation.
542
+
543
+ Args:
544
+ prompt (`str` or `List[str]`, *optional*):
545
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
546
+ instead.
547
+ height (`int`, defaults to `512`):
548
+ The height in pixels of the generated video. This is set to 512 by default for the best results.
549
+ width (`int`, defaults to `704`):
550
+ The width in pixels of the generated video. This is set to 704 by default for the best results.
551
+ num_frames (`int`, defaults to `161`):
552
+ The number of video frames to generate
553
+ num_inference_steps (`int`, *optional*, defaults to 50):
554
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
555
+ expense of slower inference.
556
+ timesteps (`List[int]`, *optional*):
557
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
558
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
559
+ passed will be used. Must be in descending order.
560
+ guidance_scale (`float`, defaults to `3`):
561
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
562
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
563
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
564
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
565
+ usually at the expense of lower image quality.
566
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
567
+ The number of videos to generate per prompt.
568
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
569
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
570
+ to make generation deterministic.
571
+ latents (`torch.Tensor`, *optional*):
572
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
573
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
574
+ tensor will be generated by sampling using the supplied random `generator`.
575
+ prompt_embeds (`torch.Tensor`, *optional*):
576
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
577
+ provided, text embeddings will be generated from `prompt` input argument.
578
+ prompt_attention_mask (`torch.Tensor`, *optional*):
579
+ Pre-generated attention mask for text embeddings.
580
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
581
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
582
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
583
+ negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
584
+ Pre-generated attention mask for negative text embeddings.
585
+ decode_timestep (`float`, defaults to `0.0`):
586
+ The timestep at which generated video is decoded.
587
+ decode_noise_scale (`float`, defaults to `None`):
588
+ The interpolation factor between random noise and denoised latents at the decode timestep.
589
+ output_type (`str`, *optional*, defaults to `"pil"`):
590
+ The output format of the generated video. Choose between
591
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
592
+ return_dict (`bool`, *optional*, defaults to `True`):
593
+ Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
594
+ attention_kwargs (`dict`, *optional*):
595
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
596
+ `self.processor` in
597
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
598
+ callback_on_step_end (`Callable`, *optional*):
599
+ A function that calls at the end of each denoising steps during the inference. The function is called
600
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
601
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
602
+ `callback_on_step_end_tensor_inputs`.
603
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
604
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
605
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
606
+ `._callback_tensor_inputs` attribute of your pipeline class.
607
+ max_sequence_length (`int`, defaults to `128`):
608
+ Maximum sequence length to use with the `prompt`.
609
+
610
+ Examples:
611
+
612
+ Returns:
613
+ [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
614
+ If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
615
+ returned where the first element is a list with the generated images.
616
+ """
617
+
618
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
619
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
620
+
621
+ # 1. Check inputs. Raise error if not correct
622
+ self.check_inputs(
623
+ prompt=prompt,
624
+ height=height,
625
+ width=width,
626
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
627
+ prompt_embeds=prompt_embeds,
628
+ negative_prompt_embeds=negative_prompt_embeds,
629
+ prompt_attention_mask=prompt_attention_mask,
630
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
631
+ )
632
+
633
+ self._guidance_scale = guidance_scale
634
+ self._attention_kwargs = attention_kwargs
635
+ self._interrupt = False
636
+
637
+ # 2. Define call parameters
638
+ if prompt is not None and isinstance(prompt, str):
639
+ batch_size = 1
640
+ elif prompt is not None and isinstance(prompt, list):
641
+ batch_size = len(prompt)
642
+ else:
643
+ batch_size = prompt_embeds.shape[0]
644
+
645
+ device = self._execution_device
646
+
647
+ # 3. Prepare text embeddings
648
+ (
649
+ prompt_embeds,
650
+ prompt_attention_mask,
651
+ negative_prompt_embeds,
652
+ negative_prompt_attention_mask,
653
+ ) = self.encode_prompt(
654
+ prompt=prompt,
655
+ negative_prompt=negative_prompt,
656
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
657
+ num_videos_per_prompt=num_videos_per_prompt,
658
+ prompt_embeds=prompt_embeds,
659
+ negative_prompt_embeds=negative_prompt_embeds,
660
+ prompt_attention_mask=prompt_attention_mask,
661
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
662
+ max_sequence_length=max_sequence_length,
663
+ device=device,
664
+ )
665
+ if self.do_classifier_free_guidance:
666
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
667
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
668
+
669
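+ # The Q8 flash-attention kernel consumes per-sequence token counts rather than a boolean mask,
+ # so the mask is collapsed to the index of the first padded position (or max_sequence_length
+ # when nothing is padded).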
+ if Q8Linear is not None and isinstance(self.transformer.transformer_blocks[0].attn1.to_q, Q8Linear):
670
+ prompt_attention_mask = prompt_attention_mask.to(torch.int64)
671
+ prompt_attention_mask = prompt_attention_mask.argmin(-1).int().squeeze()
672
+ prompt_attention_mask[prompt_attention_mask == 0] = max_sequence_length
673
+
674
+ # 4. Prepare latent variables
675
+ num_channels_latents = self.transformer.config.in_channels
676
+ latents = self.prepare_latents(
677
+ batch_size * num_videos_per_prompt,
678
+ num_channels_latents,
679
+ height,
680
+ width,
681
+ num_frames,
682
+ torch.float32,
683
+ device,
684
+ generator,
685
+ latents,
686
+ )
687
+
688
+ # 5. Prepare timesteps
689
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
690
+ latent_height = height // self.vae_spatial_compression_ratio
691
+ latent_width = width // self.vae_spatial_compression_ratio
692
+ video_sequence_length = latent_num_frames * latent_height * latent_width
693
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
694
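+ # Shift the sigma schedule according to the video token count, mirroring the
+ # resolution-dependent timestep shifting used by flow-matching schedulers.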
+ mu = calculate_shift(
695
+ video_sequence_length,
696
+ self.scheduler.config.base_image_seq_len,
697
+ self.scheduler.config.max_image_seq_len,
698
+ self.scheduler.config.base_shift,
699
+ self.scheduler.config.max_shift,
700
+ )
701
+ timesteps, num_inference_steps = retrieve_timesteps(
702
+ self.scheduler,
703
+ num_inference_steps,
704
+ device,
705
+ timesteps,
706
+ sigmas=sigmas,
707
+ mu=mu,
708
+ )
709
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
710
+ self._num_timesteps = len(timesteps)
711
+
712
+ # 6. Prepare micro-conditions
713
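+ # Scale RoPE positions from latent space back to pixel/frame space: one latent frame covers
+ # vae_temporal_compression_ratio / frame_rate seconds, and one latent pixel covers
+ # vae_spatial_compression_ratio image pixels.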
+ latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
714
+ rope_interpolation_scale = (
715
+ 1 / latent_frame_rate,
716
+ self.vae_spatial_compression_ratio,
717
+ self.vae_spatial_compression_ratio,
718
+ )
719
+
720
+ # 7. Denoising loop
721
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
722
+ for i, t in enumerate(timesteps):
723
+ if self.interrupt:
724
+ continue
725
+
726
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
727
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
728
+
729
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
730
+ timestep = t.expand(latent_model_input.shape[0])
731
+
732
+ noise_pred = self.transformer(
733
+ hidden_states=latent_model_input,
734
+ encoder_hidden_states=prompt_embeds,
735
+ timestep=timestep,
736
+ encoder_attention_mask=prompt_attention_mask,
737
+ num_frames=latent_num_frames,
738
+ height=latent_height,
739
+ width=latent_width,
740
+ rope_interpolation_scale=rope_interpolation_scale,
741
+ attention_kwargs=attention_kwargs,
742
+ return_dict=False,
743
+ )[0]
744
+ noise_pred = noise_pred.float()
745
+
746
+ if self.do_classifier_free_guidance:
747
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
748
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
749
+
750
+ # compute the previous noisy sample x_t -> x_t-1
751
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
752
+
753
+ if callback_on_step_end is not None:
754
+ callback_kwargs = {}
755
+ for k in callback_on_step_end_tensor_inputs:
756
+ callback_kwargs[k] = locals()[k]
757
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
758
+
759
+ latents = callback_outputs.pop("latents", latents)
760
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
761
+
762
+ # call the callback, if provided
763
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
764
+ progress_bar.update()
765
+
766
+ if XLA_AVAILABLE:
767
+ xm.mark_step()
768
+
769
+ if output_type == "latent":
770
+ video = latents
771
+ else:
772
+ latents = self._unpack_latents(
773
+ latents,
774
+ latent_num_frames,
775
+ latent_height,
776
+ latent_width,
777
+ self.transformer_spatial_patch_size,
778
+ self.transformer_temporal_patch_size,
779
+ )
780
+ latents = self._denormalize_latents(
781
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
782
+ )
783
+ latents = latents.to(prompt_embeds.dtype)
784
+
785
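+ # For timestep-conditioned VAEs, a small amount of fresh noise is blended back into the latents
+ # and the chosen decode timestep is passed to the decoder.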
+ if not self.vae.config.timestep_conditioning:
786
+ timestep = None
787
+ else:
788
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
789
+ if not isinstance(decode_timestep, list):
790
+ decode_timestep = [decode_timestep] * batch_size
791
+ if decode_noise_scale is None:
792
+ decode_noise_scale = decode_timestep
793
+ elif not isinstance(decode_noise_scale, list):
794
+ decode_noise_scale = [decode_noise_scale] * batch_size
795
+
796
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
797
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
798
+ :, None, None, None, None
799
+ ]
800
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
801
+
802
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
803
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
804
+
805
+ # Offload all models
806
+ self.maybe_free_model_hooks()
807
+
808
+ if not return_dict:
809
+ return (video,)
810
+
811
+ return LTXPipelineOutput(frames=video)
prompt_embeds.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2095fe12d573adf5adf6e2f0d163d494f792fe83deaeb6b20e67a28e6e44e913
3
+ size 8391808
q8_attention_processors.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ Reference:
3
+ https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/modules/attention.py
4
+ """
5
+
6
+ import torch
7
+ import q8_kernels.functional as Q8F
8
+ from diffusers.models.transformers.transformer_ltx import apply_rotary_emb
9
+ from diffusers.models.attention import Attention
10
+
11
+ NON_MM_PRECISION_TYPE = torch.bfloat16
12
+ MM_PRECISION_TYPE = torch.bfloat16
13
+
14
+
15
+ class LTXVideoQ8AttentionProcessor:
16
+ def __call__(
17
+ self,
18
+ attn: Attention,
19
+ hidden_states: torch.Tensor,
20
+ encoder_hidden_states=None,
21
+ attention_mask=None,
22
+ image_rotary_emb=None,
23
+ ) -> torch.Tensor:
24
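+ # If a 2-D padding mask is passed, collapse it to per-sequence valid-token counts, the layout
+ # expected by the Q8 flash-attention kernel below.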
+ if attention_mask is not None and attention_mask.ndim > 1:
25
+ attention_mask = attention_mask.argmin(-1).squeeze().int()
26
+
27
+ if encoder_hidden_states is None:
28
+ encoder_hidden_states = hidden_states
29
+
30
+ query = attn.to_q(hidden_states)
31
+ key = attn.to_k(encoder_hidden_states)
32
+ value = attn.to_v(encoder_hidden_states)
33
+
34
+ query = attn.norm_q(query, NON_MM_PRECISION_TYPE)
35
+ key = attn.norm_k(key, NON_MM_PRECISION_TYPE)
36
+
37
+ if image_rotary_emb is not None:
38
+ query = apply_rotary_emb(query, image_rotary_emb)
39
+ key = apply_rotary_emb(key, image_rotary_emb)
40
+
41
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
42
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
43
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
44
+
45
+ hidden_states = Q8F.flash_attention.flash_attn_func(
46
+ query, key, value, batch_mask=attention_mask, apply_qk_hadamard=True
47
+ )
48
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
49
+
50
+ hidden_states = attn.to_out[0](hidden_states)
51
+ hidden_states = attn.to_out[1](hidden_states)
52
+ return hidden_states.to(NON_MM_PRECISION_TYPE)
q8_ltx.py ADDED
@@ -0,0 +1,141 @@
1
+ """
2
+ References:
3
+ https://github.com/KONAKONA666/q8_kernels/blob/9cee3f3d4ca5ec8ab463179be32c8001e31f8f33/q8_kernels/utils/convert_weights.py
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from q8_kernels.modules.rms_norm import RMSNorm as QRMSNorm
9
+ from diffusers.models.normalization import RMSNorm
10
+ from q8_kernels.modules.activations import GELU as QGELU
11
+ from diffusers.models.activations import GELU
12
+ from q8_kernels.modules.linear import Q8Linear
13
+ from q8_attention_processors import LTXVideoQ8AttentionProcessor
14
+
15
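+ # Keep the input/output projections and conditioning embeddings in their original precision;
+ # the remaining linear layers inside the transformer blocks are swapped to int8.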
+ MODULES_TO_NOT_CONVERT = ["proj_in", "time_embed", "caption_projection", "proj_out"]
16
+
17
+
18
+ def replace_linear(model, current_key_name=None, replaced=False):
19
+ for name, child in model.named_children():
20
+ if current_key_name is None:
21
+ current_key_name = []
22
+ current_key_name.append(name)
23
+
24
+ if isinstance(child, nn.Linear) and name not in MODULES_TO_NOT_CONVERT:
25
+ # Check if the current key is not in the `modules_to_not_convert`
26
+ current_key_name_str = ".".join(current_key_name)
27
+ if not any(
28
+ (key + "." in current_key_name_str) or (key == current_key_name_str) for key in MODULES_TO_NOT_CONVERT
29
+ ):
30
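+ # Swap the floating-point nn.Linear for a Q8Linear shell; the int8 weights themselves are
+ # expected to come from the pre-quantized checkpoint loaded afterwards.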
+ new_linear = Q8Linear(
31
+ child.in_features, child.out_features, bias=child.bias is not None, device=child.weight.device
32
+ )
33
+ setattr(model, name, new_linear)
34
+ replaced = True
35
+ else:
36
+ replace_linear(model=child, current_key_name=current_key_name, replaced=replaced)
37
+
38
+ current_key_name.pop(-1)
39
+
40
+ return model, replaced
41
+
42
+
43
+ def get_parent_module_and_attr(root, dotted_name: str):
44
+ """
45
+ Splits 'a.b.c' into:
46
+ - parent module = root.a.b
47
+ - attr_name = 'c'
48
+ """
49
+ parts = dotted_name.split(".")
50
+ *parent_parts, attr_name = parts
51
+ parent_module = root
52
+ for p in parent_parts:
53
+ parent_module = getattr(parent_module, p)
54
+ return parent_module, attr_name
55
+
56
+
57
+ def replace_rms_norm(model):
58
+ modules_to_replace = []
59
+ for dotted_name, module in model.named_modules():
60
+ if isinstance(module, RMSNorm):
61
+ modules_to_replace.append((dotted_name, module))
62
+
63
+ replaced = False
64
+ for dotted_name, module in modules_to_replace:
65
+ parent, attr_name = get_parent_module_and_attr(model, dotted_name)
66
+ new_norm = QRMSNorm(
67
+ dim=module.dim,
68
+ elementwise_affine=module.elementwise_affine,
69
+ )
70
+ setattr(parent, attr_name, new_norm)
71
+ replaced = True
72
+
73
+ return model, replaced
74
+
75
+
76
+ def replace_gelu(model, replaced=False):
77
+ for name, child in model.named_children():
78
+ if isinstance(child, GELU):
79
+ new_gelu = QGELU(
80
+ dim_in=child.proj.in_features,
81
+ dim_out=child.proj.out_features,
82
+ approximate=child.approximate,
83
+ bias=child.proj.bias is not None,
84
+ )
85
+ setattr(model, name, new_gelu)
86
+ replaced = True
87
+ else:
88
+ replace_gelu(model=child, replaced=replaced)
89
+
90
+ return model, replaced
91
+
92
+
93
+ def set_attn_processors(model, processor):
94
+ def fn_recursive_attn_processor(name, module: torch.nn.Module, processor):
95
+ if hasattr(module, "set_processor"):
96
+ module.set_processor(processor)
97
+ for sub_name, child in module.named_children():
98
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
99
+
100
+ for name, module in model.named_children():
101
+ fn_recursive_attn_processor(name, module, processor)
102
+
103
+
104
+ def attn_processors(model) -> dict:
105
+ processors = {}
106
+
107
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict):
108
+ if hasattr(module, "get_processor"):
109
+ processors[f"{name}.processor"] = module.get_processor()
110
+
111
+ for sub_name, child in module.named_children():
112
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
113
+
114
+ return processors
115
+
116
+ for name, module in model.named_children():
117
+ fn_recursive_add_processors(name, module, processors)
118
+
119
+ return processors
120
+
121
+
122
+ def check_transformer_replaced_correctly(model):
123
+ for block in model.transformer_blocks:
124
+ assert isinstance(block.attn1.to_q, Q8Linear), f"{type(block.attn1.to_q)=} not linear."
125
+ assert isinstance(block.attn2.to_q, Q8Linear), f"{type(block.attn2.to_q)=} not linear."
126
+ assert block.attn1.to_q.weight.dtype == torch.int8, f"{block.attn1.to_q.weight.dtype=}."
127
+ assert block.attn2.to_q.weight.dtype == torch.int8, f"{block.attn2.to_q.weight.dtype=}."
128
+
129
+ for name, module in model.named_modules():
130
+ if "norm" in name and "norm_out" not in name:
131
+ assert isinstance(module, QRMSNorm), f"{name=}, {type(module)=}"
132
+
133
+ for block in model.transformer_blocks:
134
+ assert isinstance(block.ff.net[0], QGELU), f"{type(block.ff.net[0])=}"
135
+ if getattr(block.ff.net[0], "proj", None) is not None:
136
+ assert block.ff.net[0].proj.weight.dtype == torch.int8, f"{block.ff.net[0].proj.weight.dtype=}."
137
+
138
+ set_attn_processors(model, LTXVideoQ8AttentionProcessor())
139
+ all_attn_processors = attn_processors(model)
140
+ for k, v in all_attn_processors.items():
141
+ assert isinstance(v, LTXVideoQ8AttentionProcessor), f"{k} is not of type LTXVideoQ8AttentionProcessor."
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ torch
2
+ diffusers
3
+ transformers
4
+ accelerate
5
+ imageio
6
+ Pillow
7
+ git+https://github.com/KONAKONA666/q8_kernels.git