File size: 5,497 Bytes
132e8c4
 
 
 
 
 
1a6f91c
 
 
 
132e8c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a6f91c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132e8c4
 
 
 
 
 
 
 
 
1a6f91c
132e8c4
 
 
 
 
1a6f91c
 
132e8c4
 
 
 
 
 
 
 
1a6f91c
132e8c4
 
 
 
 
 
 
1a6f91c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132e8c4
1a6f91c
 
 
 
 
132e8c4
1a6f91c
 
 
 
132e8c4
 
1a6f91c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from typing import Dict, Any, Union, Optional
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from PIL import Image
import base64
import io
import tempfile
import numpy as np
from moviepy.editor import ImageSequenceClip
import os

class EndpointHandler:
    """Endpoint handler that turns a prompt (and optional image) into an MP4.

    Two diffusers pipelines are loaded from the same weights directory: one
    for text-to-video and one for image-to-video. Calling the handler runs the
    appropriate pipeline, encodes the frames with libx264 via moviepy and
    returns the file as a base64 string.
    """

    def __init__(self, path: str = ""):
        """Load the text-to-video and image-to-video pipelines.

        Args:
            path (str): Path to the model weights directory.
        """
        # The pipelines are deliberately NOT moved to "cuda" here:
        # enable_model_cpu_offload() manages device placement itself, and
        # pushing both full pipelines to the GPU first only raises peak GPU
        # memory for no benefit.
        # NOTE(review): both pipelines load their own copy of the weights;
        # sharing components between them would halve host memory, but is
        # left unchanged here.
        self.text_to_video = LTXPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )

        # Enable memory optimizations
        self.text_to_video.enable_model_cpu_offload()
        self.image_to_video.enable_model_cpu_offload()

        # Default frames-per-second used when a request does not specify one
        self.fps = 24

    @staticmethod
    def _frames_to_uint8(images: torch.Tensor) -> np.ndarray:
        """Convert a float video tensor into a (frames, H, W, 3) uint8 array.

        Accepted layouts (an optional leading batch dimension of size 1 is
        dropped first):
            - (frames, 3, H, W): frames-first; presumably what diffusers video
              pipelines return for ``output_type="pt"`` — TODO confirm against
              the installed diffusers version.
            - (3, frames, H, W): channels-first, the layout this handler
              originally assumed.
        When both readings are possible (3 frames of 3 channels) the
        frames-first layout wins.

        Args:
            images (torch.Tensor): Video tensor with values expected in [0, 1].

        Returns:
            np.ndarray: Array of shape (frames, H, W, 3) and dtype uint8.

        Raises:
            ValueError: If the tensor is not a single RGB video.
        """
        video = images
        if video.dim() == 5:
            if video.shape[0] != 1:
                raise ValueError(
                    f"Expected a single video, got batch size {video.shape[0]}"
                )
            video = video[0]
        if video.dim() != 4:
            raise ValueError(
                f"Expected a 4D or 5D video tensor, got shape {tuple(images.shape)}"
            )

        if video.shape[1] == 3:
            video = video.permute(0, 2, 3, 1)
        elif video.shape[0] == 3:
            video = video.permute(1, 2, 3, 0)
        else:
            raise ValueError(
                f"Cannot locate the RGB channel axis in shape {tuple(images.shape)}"
            )

        # Cast to float32 first (numpy has no bfloat16), clamp so out-of-range
        # values cannot wrap around in uint8, and round instead of truncating.
        video_np = video.detach().cpu().float().clamp(0, 1).numpy()
        return (video_np * 255).round().astype(np.uint8)

    def _create_video_file(self, images: torch.Tensor, fps: int = 24) -> bytes:
        """Convert frames to an MP4 video file.

        Args:
            images (torch.Tensor): Generated frames tensor (see
                ``_frames_to_uint8`` for the accepted layouts).
            fps (int): Frames per second for the output video.

        Returns:
            bytes: MP4 video file content.
        """
        video_np = self._frames_to_uint8(images)

        # Reserve the path atomically; tempfile.mktemp() only returns a name
        # and is deprecated because another process can claim it first.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name

        try:
            # Frames already have their final size, so no resize is needed.
            clip = ImageSequenceClip(list(video_np), fps=fps)
            clip.write_videofile(output_path, codec="libx264", audio=False)

            with open(output_path, "rb") as f:
                return f.read()
        finally:
            # Cleanup
            if os.path.exists(output_path):
                os.remove(output_path)

            # Clear memory
            del video_np
            torch.cuda.empty_cache()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process the input data and generate video using LTX.

        Args:
            data (Dict[str, Any]): Input data containing:
                - prompt (str): Text description for video generation
                - image (Optional[str]): Base64 encoded image for image-to-video generation
                - num_frames (Optional[int]): Number of frames to generate (default: 24)
                - fps (Optional[int]): Frames per second (default: 24)
                - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
                - num_inference_steps (Optional[int]): Number of inference steps (default: 50)

        Returns:
            Dict[str, Any]: Dictionary containing:
                - video: Base64 encoded MP4 video
                - content-type: MIME type of the video

        Raises:
            ValueError: If 'prompt' is missing or empty.
            RuntimeError: If image decoding, generation or encoding fails; the
                original exception is attached as ``__cause__``.
        """
        prompt = data.get("prompt")
        if not prompt:
            raise ValueError("'prompt' is required in the input data")

        # Keyword arguments shared by both pipelines
        pipeline_kwargs = {
            "prompt": prompt,
            "num_frames": data.get("num_frames", 24),
            "guidance_scale": data.get("guidance_scale", 7.5),
            "num_inference_steps": data.get("num_inference_steps", 50),
            "output_type": "pt",
        }
        fps = data.get("fps", self.fps)

        # A non-empty image selects image-to-video generation
        image_data = data.get("image")

        try:
            with torch.no_grad():
                if image_data:
                    image_bytes = base64.b64decode(image_data)
                    pipeline_kwargs["image"] = Image.open(
                        io.BytesIO(image_bytes)
                    ).convert("RGB")
                    pipeline = self.image_to_video
                else:
                    pipeline = self.text_to_video

                # LTX pipelines expose their result as `.frames`; there is no
                # `.images` attribute on the video pipeline output.
                output = pipeline(**pipeline_kwargs).frames

                video_content = self._create_video_file(output, fps=fps)
                video_base64 = base64.b64encode(video_content).decode('utf-8')

                return {
                    "video": video_base64,
                    "content-type": "video/mp4"
                }

        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Error generating video: {str(e)}") from e