LTX-Video-0.9.1-HFIE / handler.py
jbilcke-hf's picture
jbilcke-hf HF staff
Update handler.py
72c8076 verified
raw
history blame
5.5 kB
from typing import Dict, Any, Union, Optional
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from PIL import Image
import base64
import io
import tempfile
import numpy as np
from moviepy.editor import ImageSequenceClip
import os
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for LTX-Video.

    Serves text-to-video (``LTXPipeline``) and image-to-video
    (``LTXImageToVideoPipeline``) generation and returns the result as a
    base64-encoded MP4.
    """

    def __init__(self, path: str = ""):
        """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory.
        """
        # Load both pipelines with bfloat16 precision as recommended in docs.
        # They are deliberately NOT moved with .to("cuda"): enable_model_cpu_offload()
        # manages device placement itself. Moving both full pipelines to the GPU
        # first would hold two complete copies of the weights on the device during
        # startup, which defeats the purpose of offloading.
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        # Enable memory optimizations
        self.text_to_video.enable_model_cpu_offload()
        self.image_to_video.enable_model_cpu_offload()

        # Default frames per second for the encoded video
        self.fps = 24

    @staticmethod
    def _frames_to_uint8(images: torch.Tensor) -> np.ndarray:
        """Convert generated frames to an ``(F, H, W, C)`` uint8 array.

        Args:
            images (torch.Tensor): Frames as returned by
                ``pipeline(..., output_type="pt").frames[0]``, i.e.
                ``(num_frames, channels, height, width)`` with values in ``[0, 1]``.
                A leading batch dimension of size 1 is also accepted.

        Returns:
            np.ndarray: ``(num_frames, height, width, channels)`` uint8 frames.

        Raises:
            ValueError: If ``images`` is not a 4D frame tensor (after dropping
                an optional size-1 batch dimension).
        """
        if images.dim() == 5 and images.shape[0] == 1:
            images = images[0]
        if images.dim() != 4:
            raise ValueError(
                "Expected a (frames, channels, height, width) tensor, "
                f"got shape {tuple(images.shape)}"
            )
        # The channel axis is 1 (frames-first layout), so move it last for
        # moviepy. Clamp + round instead of a bare cast so values at or just
        # past 1.0 do not wrap around or get truncated.
        frames = (
            images.detach()
            .permute(0, 2, 3, 1)
            .float()
            .clamp(0.0, 1.0)
            .mul(255.0)
            .round()
            .to(torch.uint8)
        )
        return frames.cpu().numpy()

    @staticmethod
    def _decode_base64_payload(payload: Union[str, bytes]) -> bytes:
        """Decode a base64 image payload.

        Accepts either raw base64 or a data URI such as
        ``data:image/png;base64,<...>``. The prefix must be stripped explicitly:
        ``base64.b64decode`` would otherwise decode its alphabet characters
        into leading garbage bytes.
        """
        if isinstance(payload, str) and payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
        return base64.b64decode(payload)

    def _create_video_file(self, images: torch.Tensor, fps: int = 24) -> bytes:
        """Convert frames to an MP4 (H.264) video file.

        Args:
            images (torch.Tensor): Generated frames, ``(num_frames, channels,
                height, width)`` in ``[0, 1]`` (see ``_frames_to_uint8``).
            fps (int): Frames per second for the output video.

        Returns:
            bytes: MP4 video file content.
        """
        frames = self._frames_to_uint8(images)
        # mkstemp creates the file atomically (mktemp is deprecated and racy).
        # Close our descriptor so ffmpeg can reopen and overwrite the path.
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        clip = None
        try:
            clip = ImageSequenceClip(list(frames), fps=fps)
            clip.write_videofile(output_path, codec="libx264", audio=False)
            with open(output_path, "rb") as f:
                return f.read()
        finally:
            if clip is not None:
                clip.close()
            if os.path.exists(output_path):
                os.remove(output_path)
            # Clear memory
            del frames
            torch.cuda.empty_cache()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate video using LTX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - prompt (str): Text description for video generation
                - image (Optional[str]): Base64 encoded image (optionally a
                  ``data:`` URI) for image-to-video generation
                - num_frames (Optional[int]): Number of frames to generate (default: 24)
                - fps (Optional[int]): Frames per second (default: 24)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 50)

        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: Base64 encoded MP4 video
                - content-type: MIME type of the video

        Raises:
            ValueError: If ``prompt`` is missing, or a numeric parameter cannot
                be converted to a number.
            RuntimeError: If decoding the image, generation, or encoding fails
                (the original exception is chained as ``__cause__``).
        """
        prompt = data.get("prompt")
        if not prompt:
            raise ValueError("'prompt' is required in the input data")

        # JSON clients may send numbers as strings; coerce so the pipeline
        # receives the types it expects.
        # NOTE(review): LTX's VAE appears to compress time 8x, so the pipeline
        # likely yields (num_frames - 1) // 8 * 8 + 1 frames (e.g. 24 -> 17).
        # Confirm before relying on exact frame counts.
        num_frames = int(data.get("num_frames", 24))
        fps = data.get("fps", self.fps)
        guidance_scale = float(data.get("guidance_scale", 7.5))
        num_inference_steps = int(data.get("num_inference_steps", 50))

        # Check if image is provided for image-to-video generation
        image_data = data.get("image")

        generation_kwargs = {
            "prompt": prompt,
            "num_frames": num_frames,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "output_type": "pt",
        }

        try:
            with torch.no_grad():
                if image_data:
                    image_bytes = self._decode_base64_payload(image_data)
                    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                    output = self.image_to_video(image=image, **generation_kwargs).frames[0]
                else:
                    output = self.text_to_video(**generation_kwargs).frames[0]

            video_content = self._create_video_file(output, fps=fps)
        except Exception as e:
            raise RuntimeError(f"Error generating video: {str(e)}") from e

        return {
            "video": base64.b64encode(video_content).decode("utf-8"),
            "content-type": "video/mp4",
        }