jbilcke-hf HF staff commited on
Commit
99df0e2
·
verified ·
1 Parent(s): 28cbc54

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -31
handler.py CHANGED
@@ -155,37 +155,50 @@ class EndpointHandler:
155
  Returns:
156
  Tuple of (video data URI, metadata dictionary)
157
  """
158
- # Process video with Varnish
159
- result = await self.varnish(
160
- input_data=frames,
161
- input_fps=config.fps,
162
- upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
163
- enable_interpolation=config.enable_interpolation,
164
- output_fps=config.fps
165
- )
166
-
167
- # Convert to data URI
168
- video_uri = await result.write(
169
- output_type="data-uri",
170
- output_format="mp4",
171
- output_codec="h264",
172
- output_quality=23
173
- )
174
-
175
- # Collect metadata
176
- metadata = {
177
- "width": result.metadata.width,
178
- "height": result.metadata.height,
179
- "num_frames": result.metadata.frame_count,
180
- "fps": result.metadata.fps,
181
- "duration": result.metadata.duration,
182
- "num_inference_steps": config.num_inference_steps,
183
- "seed": config.seed,
184
- "upscale_factor": config.upscale_factor,
185
- "interpolation_enabled": config.enable_interpolation
186
- }
187
-
188
- return video_uri, metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
191
  """Process incoming requests for video generation
 
155
  Returns:
156
  Tuple of (video data URI, metadata dictionary)
157
  """
158
+ try:
159
+ logger.info(f"Original frames shape: {frames.shape}")
160
+
161
+ # Remove batch dimension if present
162
+ if len(frames.shape) == 5:
163
+ frames = frames.squeeze(0) # Remove batch dimension
164
+
165
+ logger.info(f"Processed frames shape: {frames.shape}")
166
+
167
+ # Process video with Varnish
168
+ result = await self.varnish(
169
+ input_data=frames,
170
+ input_fps=config.fps,
171
+ output_fps=config.fps,
172
+ upscale_factor=config.upscale_factor if config.upscale_factor > 1 else None,
173
+ enable_interpolation=config.enable_interpolation
174
+ )
175
+
176
+ # Convert to data URI
177
+ video_uri = await result.write(
178
+ output_type="data-uri",
179
+ output_format="mp4",
180
+ output_codec="h264",
181
+ output_quality=23
182
+ )
183
+
184
+ # Collect metadata
185
+ metadata = {
186
+ "width": result.metadata.width,
187
+ "height": result.metadata.height,
188
+ "num_frames": result.metadata.frame_count,
189
+ "fps": result.metadata.fps,
190
+ "duration": result.metadata.duration,
191
+ "num_inference_steps": config.num_inference_steps,
192
+ "seed": config.seed,
193
+ "upscale_factor": config.upscale_factor,
194
+ "interpolation_enabled": config.enable_interpolation
195
+ }
196
+
197
+ return video_uri, metadata
198
+
199
+ except Exception as e:
200
+ logger.error(f"Error in process_frames: {str(e)}")
201
+ raise RuntimeError(f"Failed to process frames: {str(e)}")
202
 
203
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
204
  """Process incoming requests for video generation