jbilcke-hf committed
Commit 132e8c4 (verified)
Parent(s): 8659e21

Upload handler.py

Files changed (1):
  handler.py +92 -0
handler.py ADDED
@@ -0,0 +1,92 @@
+ from typing import Dict, Any
+ import torch
+ from diffusers import LTXPipeline, LTXImageToVideoPipeline
+ from PIL import Image
+ import base64
+ import io
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         """Initialize the LTX Video handler with both text-to-video and image-to-video pipelines.
+
+         Args:
+             path (str): Path to the model weights directory
+         """
+         # Load both pipelines with bfloat16 precision as recommended in docs
+         self.text_to_video = LTXPipeline.from_pretrained(
+             path,
+             torch_dtype=torch.bfloat16
+         )
+
+         self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
+             path,
+             torch_dtype=torch.bfloat16
+         )
+
+         # Enable model CPU offload; it manages GPU placement itself, so the pipelines are not moved to CUDA above
+         self.text_to_video.enable_model_cpu_offload()
+         self.image_to_video.enable_model_cpu_offload()
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """Process the input data and generate video using LTX.
+
+         Args:
+             data (Dict[str, Any]): Input data containing:
+                 - prompt (str): Text description for video generation
+                 - image (Optional[str]): Base64 encoded image for image-to-video generation
+                 - num_frames (Optional[int]): Number of frames to generate (default: 24)
+                 - guidance_scale (Optional[float]): Guidance scale (default: 7.5)
+                 - num_inference_steps (Optional[int]): Number of inference steps (default: 50)
+
+         Returns:
+             Dict[str, Any]: Dictionary containing:
+                 - frames: List of base64 encoded frames
+         """
+         # Extract parameters
+         prompt = data.get("prompt")
+         if not prompt:
+             raise ValueError("'prompt' is required in the input data")
+
+         # Get optional parameters with defaults
+         num_frames = data.get("num_frames", 24)
+         guidance_scale = data.get("guidance_scale", 7.5)
+         num_inference_steps = data.get("num_inference_steps", 50)
+
+         # Check if an image is provided for image-to-video generation
+         image_data = data.get("image")
+
+         try:
+             if image_data:
+                 # Decode the base64 image
+                 image_bytes = base64.b64decode(image_data)
+                 image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+
+                 # Generate video from image
+                 output = self.image_to_video(
+                     prompt=prompt,
+                     image=image,
+                     num_frames=num_frames,
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_inference_steps
+                 )
+             else:
+                 # Generate video from text only
+                 output = self.text_to_video(
+                     prompt=prompt,
+                     num_frames=num_frames,
+                     guidance_scale=guidance_scale,
+                     num_inference_steps=num_inference_steps
+                 )
+
+             # Convert frames to base64
+             frames = []
+             for frame in output.frames[0]:  # First element contains the frames
+                 buffer = io.BytesIO()
+                 frame.save(buffer, format="PNG")
+                 frame_base64 = base64.b64encode(buffer.getvalue()).decode()
+                 frames.append(frame_base64)
+
+             return {"frames": frames}
+
+         except Exception as e:
+             raise RuntimeError(f"Error generating video: {e}") from e
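
For reference, the sketch below shows one way to exercise this handler locally. It is illustrative only: the "Lightricks/LTX-Video" weights path, the start_frame.png input, and the frame_000.png output name are assumptions, not part of this commit.

import base64
import io

from PIL import Image

from handler import EndpointHandler

# Assumed weights location; substitute whatever path the endpoint actually mounts.
handler = EndpointHandler("Lightricks/LTX-Video")

# Text-to-video: only "prompt" is required; the rest fall back to the defaults above.
payload = {
    "prompt": "A chic cat walking through a neon-lit alley at night",
    "num_frames": 24,
    "guidance_scale": 7.5,
    "num_inference_steps": 50,
}
result = handler(payload)

# Frames come back as base64-encoded PNGs; decode and save the first one.
first_frame = Image.open(io.BytesIO(base64.b64decode(result["frames"][0])))
first_frame.save("frame_000.png")

# Image-to-video: add a base64-encoded conditioning image to the same payload.
with open("start_frame.png", "rb") as f:  # any RGB still image on disk
    payload["image"] = base64.b64encode(f.read()).decode()
result = handler(payload)
print(f"generated {len(result['frames'])} frames")

On a deployed endpoint the same keys would be sent as the JSON body of a POST request, and the response would carry the base64 frames under "frames".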