# Page-header residue captured together with this file (not part of the program):
#   tsqn's picture
#   Update app.py
#   1fd2df8 verified
import spaces
from datetime import datetime
import gc
import gradio as gr
import numpy as np
import random
from pathlib import Path
import os
from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXVideoTransformer3DModel
from diffusers.utils import export_to_video
from transformers import T5EncoderModel, T5Tokenizer
import torch
from utils import install_packages
def _configure_torch_runtime():
    """Apply the process-wide PyTorch settings this app runs with."""
    # Turn on TF32 for both the CUDA matmul backend and cuDNN.
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True

    # TorchScript is switched off and autograd is disabled globally.
    torch.jit._state.disable()
    torch.set_grad_enabled(False)

    # Start from a clean slate before any model weights are loaded.
    gc.collect()
    torch.cuda.empty_cache()


_configure_torch_runtime()
# Repository that provides the "text_encoder" and "tokenizer" subfolders.
# NOTE(review): a hub repo id is wrapped in Path here; this presumably relies
# on POSIX-style separators when it is converted back to a string -- confirm.
ckpt_path = Path("a-r-r-o-w/LTX-Video-0.9.1-diffusers")

# Single-file checkpoint used for the transformer, the VAE and the pipeline.
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors"

# Denoising transformer, loaded in bfloat16.
transformer = LTXVideoTransformer3DModel.from_single_file(
    single_file_url,
    torch_dtype=torch.bfloat16,
)

# VAE from the same checkpoint: switched to eval mode and moved to the GPU.
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = vae.eval().to("cuda")

# T5 text encoder from the diffusers-format repo: eval mode, on the GPU.
text_encoder = T5EncoderModel.from_pretrained(
    ckpt_path, subfolder="text_encoder", torch_dtype=torch.bfloat16
)
text_encoder = text_encoder.eval().to("cuda")

tokenizer = T5Tokenizer.from_pretrained(ckpt_path, subfolder="tokenizer")

# Assemble the pipeline around the components loaded above.
pipeline = LTXPipeline.from_single_file(
    single_file_url,
    transformer=transformer,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
    torch_dtype=torch.bfloat16,
)

# Model CPU offload is intentionally left off; the VAE's tiling and slicing
# modes are enabled and the whole pipeline is placed on the GPU instead.
pipeline.vae.enable_tiling()
pipeline.vae.enable_slicing()
pipeline = pipeline.to("cuda")
# Upper bounds shared by the UI sliders and the random seed generator.
MAX_IMAGE_SIZE = 1280  # maximum of the width and height sliders
MAX_SEED = np.iinfo(np.int32).max  # largest signed 32-bit integer
@spaces.GPU()
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width=704,
    height=448,
    num_frames=129,
    fps=24,
    num_inference_steps=30,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video from a text prompt and return the path of the MP4 file.

    Args:
        prompt: Text describing the video to generate.
        negative_prompt: Text describing what the video should avoid.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, a fresh seed in ``[0, MAX_SEED]`` is drawn.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        num_frames: Number of frames passed to the pipeline.
        fps: Frame rate used when the frames are written to disk.
        num_inference_steps: Number of inference steps passed to the pipeline.
        progress: Gradio progress tracker. It is not used directly in the
            body; presumably its presence lets Gradio mirror tqdm progress
            in the UI -- confirm against the Gradio docs.

    Returns:
        Path (as a string) of the video written under ``./output``.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # UI widgets may hand numbers over as floats; the generator and the
    # pipeline need integers.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    try:
        # inference_mode() already disables gradient tracking, so a separate
        # no_grad() context is unnecessary.
        with torch.autocast("cuda", dtype=torch.bfloat16), torch.inference_mode():
            video = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=int(width),
                height=int(height),
                num_frames=int(num_frames),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            ).frames[0]

        # Microseconds are included so that two requests finishing within the
        # same second do not overwrite each other's file.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        os.makedirs("output", exist_ok=True)
        output_path = f"./output/output_{timestamp}.mp4"
        export_to_video(video, output_path, fps=int(fps))
    finally:
        # Release memory whether or not generation succeeded.
        gc.collect()
        torch.cuda.empty_cache()

    return output_path
# Stylesheet handed to gr.Blocks: centres the element with id "col-container"
# and caps its width at 640px.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# Gradio UI: prompt inputs, a run button, the resulting video, and an
# accordion of generation settings. Every control is wired into `infer`.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # The heading previously carried leftover template text that called
        # this a text-to-image app.
        gr.Markdown("# LTX-Video Text-to-Video")

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                lines=3,
                value=(
                    "A woman with long brown hair and light skin smiles at another woman with long blonde hair. "
                    "The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. "
                    "The camera angle is a close-up, focused on the woman with brown hair's face. "
                    "The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. "
                    "The scene appears to be real-life footage"
                ),
            )
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                lines=3,
                value="worst quality, blurry, distorted",
            )

        with gr.Row():
            run_button = gr.Button("Run", scale=0, variant="huggingface")

        with gr.Row():
            result = gr.Video(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=704,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=448,
                )

            with gr.Row():
                # With minimum=1 and step=32 the slider offers 1, 33, 65, ..., 257.
                num_frames = gr.Slider(
                    label="Number of frames",
                    minimum=1,
                    maximum=257,
                    step=32,
                    value=129,
                )
                fps = gr.Slider(
                    label="Number of frames per second",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=24,
                )

            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=30,
                )

    # Clicking "Run" or submitting the prompt box both start a generation.
    # The input order must match the positional parameters of `infer`.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            num_frames,
            fps,
            num_inference_steps,
        ],
        outputs=[result],
    )
# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # install_packages comes from the local `utils` module; what it installs
    # is not visible from this file.
    install_packages()
    # Start the Gradio server for the UI defined above.
    demo.launch()