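"""Run LTX-Video text-to-video inference, optionally with a Q8 (quantized) transformer.

The prompt is encoded first with the T5 text encoder, which is then freed before the
denoising pipeline is loaded, so the text encoder and the transformer never have to sit
in GPU memory at the same time. Pass --q8_transformer_path (a local directory or a Hub
repo id containing `diffusion_pytorch_model.safetensors`) to use the q8_kernels-based
transformer with CUDA-graph capture; omit it to run the stock bfloat16 pipeline.
"""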
import argparse
import gc
import os

import safetensors.torch
import torch
from diffusers import LTXPipeline, LTXVideoTransformer3DModel
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from q8_kernels.graph.graph import make_dynamic_graphed_callable

from q8_ltx import check_transformer_replaced_correctly, replace_gelu, replace_linear, replace_rms_norm


# Taken from
# https://github.com/KONAKONA666/LTX-Video/blob/c8462ed2e359cda4dec7f49d98029994e850dc90/inference.py#L115C1-L138C28
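# Builds a short, filesystem-friendly slug from the prompt for the default output filename.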
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Keep only letters and spaces, lowercased.
    clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
    # Split into words.
    words = clean_text.split()
    # Greedily keep words while the accumulated length stays within max_len.
    result = []
    current_length = 0
    for word in words:
        new_length = current_length + len(word)
        if new_length <= max_len:
            result.append(word)
            current_length += len(word)
        else:
            break
    return "-".join(result)
def load_text_encoding_pipeline():
    return LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video", transformer=None, vae=None, torch_dtype=torch.bfloat16
    ).to("cuda")
def encode_prompt(pipe, prompt, negative_prompt, max_sequence_length=128):
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
        prompt=prompt, negative_prompt=negative_prompt, max_sequence_length=max_sequence_length
    )
    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
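# Instantiate the transformer on the meta device, swap in the Q8 GELU/linear/RMSNorm
# modules from q8_kernels, then materialize the quantized weights from a local directory
# or a Hub repo.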
def load_q8_transformer(args):
    # Build the transformer skeleton on the meta device so no real weights are allocated yet.
    with torch.device("meta"):
        transformer_config = LTXVideoTransformer3DModel.load_config("Lightricks/LTX-Video", subfolder="transformer")
        transformer = LTXVideoTransformer3DModel.from_config(transformer_config)

    # Replace GELU, linear, and RMSNorm modules with their Q8 counterparts.
    transformer = replace_gelu(transformer)[0]
    transformer = replace_linear(transformer)[0]
    transformer = replace_rms_norm(transformer)[0]

    # Load the quantized checkpoint from a local directory or from the Hub.
    if os.path.isfile(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors"):
        state_dict = safetensors.torch.load_file(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors")
    else:
        state_dict = safetensors.torch.load_file(
            hf_hub_download(args.q8_transformer_path, "diffusion_pytorch_model.safetensors")
        )
    # `assign=True` materializes the meta tensors directly from the loaded state dict.
    transformer.load_state_dict(state_dict, strict=True, assign=True)
    check_transformer_replaced_correctly(transformer)
    return transformer
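# Two-stage flow: encode the prompts and free the text encoder, then load the
# (optionally Q8) transformer pipeline and render the video.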
def main(args):
    # Stage 1: encode prompts, then release the text encoder before loading the transformer.
    text_encoding_pipeline = load_text_encoding_pipeline()
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = encode_prompt(
        pipe=text_encoding_pipeline,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_sequence_length,
    )
    del text_encoding_pipeline
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    # Stage 2: load the denoising pipeline with either the Q8 or the stock transformer.
    if args.q8_transformer_path:
        transformer = load_q8_transformer(args)
        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", transformer=None, text_encoder=None)
        pipe.transformer = transformer
        pipe.transformer = pipe.transformer.to(torch.bfloat16)
        # Cast the transformer blocks to float32, keeping the scale_shift_table
        # parameters in bfloat16.
        for b in pipe.transformer.transformer_blocks:
            b.to(dtype=torch.float)
        for n, m in pipe.transformer.transformer_blocks.named_parameters():
            if "scale_shift_table" in n:
                m.data = m.data.to(torch.bfloat16)
        # Capture the transformer forward pass with CUDA graphs via q8_kernels.
        pipe.transformer.forward = make_dynamic_graphed_callable(pipe.transformer.forward)
        pipe.vae = pipe.vae.to(torch.bfloat16)
    else:
        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
    pipe = pipe.to("cuda")

    # `--resolution` is given as HEIGHTxWIDTH (e.g. 480x704).
    width, height = args.resolution.split("x")[::-1]
    video = pipe(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        width=int(width),
        height=int(height),
        num_frames=args.num_frames,
        num_inference_steps=args.steps,
        max_sequence_length=args.max_sequence_length,
        generator=torch.manual_seed(2025),
    ).frames[0]
    print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024 / 1024} MB.")

    if args.out_path is None:
        filename_from_prompt = convert_prompt_to_filename(args.prompt, max_len=30)
        base_filename = f"{filename_from_prompt}_{args.num_frames}x{height}x{width}"
        base_filename += "_q8" if args.q8_transformer_path is not None else ""
        args.out_path = base_filename + ".mp4"
    export_to_video(video, args.out_path, fps=24)
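# Example invocation (the script name and the Q8 checkpoint path are placeholders):
#   python infer_ltx.py --prompt "A woman walks down a busy street at night" \
#       --q8_transformer_path <local-dir-or-hub-repo-with-q8-weights>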
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--q8_transformer_path", type=str, default=None)
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--negative_prompt", type=str, default=None)
    parser.add_argument("--num_frames", type=int, default=81)
    parser.add_argument("--resolution", type=str, default="480x704")
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--max_sequence_length", type=int, default=512)
    parser.add_argument("--out_path", type=str, default=None)
    args = parser.parse_args()
    main(args)