import argparse
import gc
import os

import safetensors.torch
import torch
from diffusers import LTXPipeline, LTXVideoTransformer3DModel
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from q8_kernels.graph.graph import make_dynamic_graphed_callable

from q8_ltx import check_transformer_replaced_correctly, replace_gelu, replace_linear, replace_rms_norm


# Taken from
# https://github.com/KONAKONA666/LTX-Video/blob/c8462ed2e359cda4dec7f49d98029994e850dc90/inference.py#L115C1-L138C28
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Keep only letters and spaces, lowercased.
    clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
    # Split into words.
    words = clean_text.split()
    # Accumulate words until the length budget is exhausted; separator
    # characters are not counted against the budget.
    result = []
    current_length = 0
    for word in words:
        new_length = current_length + len(word)
        if new_length <= max_len:
            result.append(word)
            current_length += len(word)
        else:
            break
    return "-".join(result)


def load_text_encoding_pipeline():
    # Load only the text encoder/tokenizer; skipping the transformer and VAE
    # keeps their weights out of memory during prompt encoding.
    return LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video", transformer=None, vae=None, torch_dtype=torch.bfloat16
    ).to("cuda")


def encode_prompt(pipe, prompt, negative_prompt, max_sequence_length=128):
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
        prompt=prompt, negative_prompt=negative_prompt, max_sequence_length=max_sequence_length
    )
    return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask


def load_q8_transformer(args):
    # Instantiate the transformer on the meta device so no real memory is
    # allocated before the quantized weights arrive.
    with torch.device("meta"):
        transformer_config = LTXVideoTransformer3DModel.load_config("Lightricks/LTX-Video", subfolder="transformer")
        transformer = LTXVideoTransformer3DModel.from_config(transformer_config)

    # Swap in the Q8 replacements for GELU, linear, and RMSNorm modules.
    transformer = replace_gelu(transformer)[0]
    transformer = replace_linear(transformer)[0]
    transformer = replace_rms_norm(transformer)[0]

    # Load the pre-quantized state dict from a local path or from the Hub.
    if os.path.isfile(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors"):
        state_dict = safetensors.torch.load_file(f"{args.q8_transformer_path}/diffusion_pytorch_model.safetensors")
    else:
        state_dict = safetensors.torch.load_file(
            hf_hub_download(args.q8_transformer_path, "diffusion_pytorch_model.safetensors")
        )

    # `assign=True` materializes the meta tensors with the loaded weights.
    transformer.load_state_dict(state_dict, strict=True, assign=True)
    check_transformer_replaced_correctly(transformer)
    return transformer


@torch.no_grad()
def main(args):
    # Encode the prompt first, then free the text-encoding pipeline before the
    # transformer and VAE are loaded, to keep peak memory low.
    text_encoding_pipeline = load_text_encoding_pipeline()
    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = encode_prompt(
        pipe=text_encoding_pipeline,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_sequence_length,
    )
    del text_encoding_pipeline
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    if args.q8_transformer_path:
        transformer = load_q8_transformer(args)
        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", transformer=None, text_encoder=None)
        pipe.transformer = transformer
        pipe.transformer = pipe.transformer.to(torch.bfloat16)
        # The Q8 kernels run the transformer blocks in float32, except the
        # scale-shift tables, which stay in bfloat16.
        for b in pipe.transformer.transformer_blocks:
            b.to(dtype=torch.float)
        for n, m in pipe.transformer.transformer_blocks.named_parameters():
            if "scale_shift_table" in n:
                m.data = m.data.to(torch.bfloat16)
        # CUDA-graph the transformer forward to cut kernel-launch overhead.
        pipe.transformer.forward = make_dynamic_graphed_callable(pipe.transformer.forward)
        pipe.vae = pipe.vae.to(torch.bfloat16)
    else:
        pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", text_encoder=None, torch_dtype=torch.bfloat16)
    pipe = pipe.to("cuda")

    # `--resolution` is "heightxwidth" (e.g. "480x704"), so reverse the split.
    width, height = args.resolution.split("x")[::-1]
    video = pipe(
        prompt_embeds=prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_embeds=negative_prompt_embeds,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        width=int(width),
        height=int(height),
        num_frames=args.num_frames,
        num_inference_steps=args.steps,
        max_sequence_length=args.max_sequence_length,
        # `torch.manual_seed` returns the seeded default generator.
        generator=torch.manual_seed(2025),
    ).frames[0]
    print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024 / 1024} MB.")

    if args.out_path is None:
        filename_from_prompt = convert_prompt_to_filename(args.prompt, max_len=30)
        base_filename = f"{filename_from_prompt}_{args.num_frames}x{height}x{width}"
        base_filename += "_q8" if args.q8_transformer_path is not None else ""
        args.out_path = base_filename + ".mp4"
    export_to_video(video, args.out_path, fps=24)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--q8_transformer_path", type=str, default=None)
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--negative_prompt", type=str, default=None)
    parser.add_argument("--num_frames", type=int, default=81)
    parser.add_argument("--resolution", type=str, default="480x704", help="Video resolution as heightxwidth.")
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--max_sequence_length", type=int, default=512)
    parser.add_argument("--out_path", type=str, default=None)
    args = parser.parse_args()
    main(args)
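
# Example invocation (illustrative; "sayakpaul/q8-ltx-video" is an assumed repo
# id for the quantized transformer, substitute your own local path or Hub id):
#
#   python inference.py \
#     --prompt "A woman walks along a beach at sunset" \
#     --q8_transformer_path sayakpaul/q8-ltx-video
#
# Omit --q8_transformer_path to run the stock bfloat16 transformer instead.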