cronos3k commited on
Commit
b9f9055
·
verified ·
1 Parent(s): 576aa9a

Update app.py

Browse files

Key changes based on the paper's implementation:

Split the processing into clear stages with memory cleanup between each:

Stage 1: Generate sparse structure
Stage 2: Generate video preview in batches
GLB generation when requested


Added batched processing for video frame generation:

Process 30 frames at a time instead of all 120 at once
Clear CUDA cache after each batch


Added explicit memory management:

torch.cuda.empty_cache() calls at key points
Explicit deletion of large temporary data
Clear video data after saving


Separated high-quality and reduced GLB generation into distinct functions
Added progress visibility with verbose=True for GLB generation to track progress

This should help prevent GPU timeouts by:

Breaking up large operations into smaller chunks
Managing memory more efficiently
Cleaning up resources between steps

Files changed (1) hide show
  1. app.py +57 -50
app.py CHANGED
@@ -35,13 +35,6 @@ def end_session(req: gr.Request):
35
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
36
  """
37
  Preprocess the input image.
38
-
39
- Args:
40
- image (Image.Image): The input image.
41
-
42
- Returns:
43
- str: uuid of the trial.
44
- Image.Image: The preprocessed image.
45
  """
46
  processed_image = pipeline.preprocess_image(image)
47
  return processed_image
@@ -102,11 +95,11 @@ def image_to_3d(
102
  req: gr.Request,
103
  ) -> Tuple[dict, str]:
104
  """
105
- Convert an image to a 3D model with memory management.
106
  """
107
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
108
 
109
- # Generate base outputs
110
  outputs = pipeline.run(
111
  image,
112
  seed=seed,
@@ -121,68 +114,67 @@ def image_to_3d(
121
  "cfg_strength": slat_guidance_strength,
122
  },
123
  )
124
-
125
- # Clear CUDA cache after model generation
126
- torch.cuda.empty_cache()
127
 
128
- # Generate video preview in smaller batches
129
- video = []
130
- video_geo = []
 
 
 
131
  batch_size = 30 # Process 30 frames at a time
132
  num_frames = 120
133
 
134
  for i in range(0, num_frames, batch_size):
135
  end_idx = min(i + batch_size, num_frames)
136
- curr_frames = end_idx - i
137
-
138
- # Generate color frames
139
  batch_frames = render_utils.render_video(
140
- outputs['gaussian'][0],
141
- num_frames=curr_frames,
142
  start_frame=i
143
  )['color']
144
- video.extend(batch_frames)
145
 
146
- # Generate geometry frames
147
  batch_geo = render_utils.render_video(
148
- outputs['mesh'][0],
149
- num_frames=curr_frames,
150
  start_frame=i
151
  )['normal']
152
- video_geo.extend(batch_geo)
153
 
154
  # Clear cache after each batch
155
  torch.cuda.empty_cache()
156
-
157
- # Combine and save video
158
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
159
  trial_id = str(uuid.uuid4())
160
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
161
  imageio.mimsave(video_path, video, fps=15)
162
 
163
- # Clear memory
 
 
164
  del video
165
- del video_geo
166
  torch.cuda.empty_cache()
167
 
168
- # Pack state and return
169
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
170
  return state, video_path
171
 
172
  @spaces.GPU
173
- def export_full_quality_glb(
174
  state: dict,
175
  req: gr.Request,
176
  ) -> Tuple[str, str]:
177
  """
178
- Export a full-quality GLB file with memory management.
179
  """
180
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
181
  gs, mesh, trial_id = unpack_state(state)
182
 
183
- # Clear cache before starting
184
  torch.cuda.empty_cache()
185
 
 
186
  glb = postprocessing_utils.to_glb(
187
  gs,
188
  mesh,
@@ -192,37 +184,51 @@ def export_full_quality_glb(
192
  texture_size=2048, # Maximum texture resolution
193
  verbose=True # Show progress
194
  )
 
195
  glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
196
  glb.export(glb_path)
197
 
198
- # Clear cache after finishing
199
  torch.cuda.empty_cache()
200
  return glb_path, glb_path
201
 
202
  @spaces.GPU
203
- def extract_glb(
204
  state: dict,
205
  mesh_simplify: float,
206
  texture_size: int,
207
  req: gr.Request,
208
  ) -> Tuple[str, str]:
209
  """
210
- Extract a GLB file from the 3D model.
211
  """
212
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
213
  gs, mesh, trial_id = unpack_state(state)
214
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
215
- glb_path = os.path.join(user_dir, f"{trial_id}.glb")
 
 
 
 
 
 
 
 
 
 
216
  glb.export(glb_path)
 
 
 
217
  return glb_path, glb_path
218
 
219
  with gr.Blocks(delete_cache=(600, 600)) as demo:
220
  gr.Markdown("""
221
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
222
- * Upload an image and click "Generate" to create a 3D asset. If the image has alpha channel, it be used as the mask. Otherwise, we use `rembg` to remove the background.
223
  * After generation:
224
- * Click "Download Full-Quality GLB" for maximum quality
225
- * Or use GLB Extraction Settings for a reduced size version
226
  """)
227
 
228
  with gr.Row():
@@ -242,12 +248,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
242
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
243
 
244
  generate_btn = gr.Button("Generate")
 
245
 
246
  with gr.Accordion(label="GLB Extraction Settings", open=False):
247
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
248
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
249
 
250
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
251
 
252
  with gr.Column():
253
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
@@ -258,7 +265,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
258
 
259
  output_buf = gr.State()
260
 
261
- # Example images at the bottom of the page
262
  with gr.Row():
263
  examples = gr.Examples(
264
  examples=[
@@ -291,12 +298,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
291
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
292
  outputs=[output_buf, video_output],
293
  ).then(
294
- lambda: [gr.Button(interactive=True), gr.Button(interactive=True), gr.Button(interactive=False)],
295
- outputs=[download_full, extract_glb_btn, download_reduced],
296
  )
297
 
298
- download_full.click(
299
- export_full_quality_glb,
300
  inputs=[output_buf],
301
  outputs=[model_output, download_full],
302
  ).then(
@@ -304,8 +311,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
304
  outputs=[download_full],
305
  )
306
 
307
- extract_glb_btn.click(
308
- extract_glb,
309
  inputs=[output_buf, mesh_simplify, texture_size],
310
  outputs=[model_output, download_reduced],
311
  ).then(
 
35
  def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
36
  """
37
  Preprocess the input image.
 
 
 
 
 
 
 
38
  """
39
  processed_image = pipeline.preprocess_image(image)
40
  return processed_image
 
95
  req: gr.Request,
96
  ) -> Tuple[dict, str]:
97
  """
98
+ Convert an image to a 3D model.
99
  """
100
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
101
 
102
+ # First stage: Generate sparse structure
103
  outputs = pipeline.run(
104
  image,
105
  seed=seed,
 
114
  "cfg_strength": slat_guidance_strength,
115
  },
116
  )
 
 
 
117
 
118
+ # Clear CUDA cache after structure generation
119
+ torch.cuda.empty_cache()
120
+
121
+ # Second stage: Generate video preview in batches
122
+ video_frames = []
123
+ video_geo_frames = []
124
  batch_size = 30 # Process 30 frames at a time
125
  num_frames = 120
126
 
127
  for i in range(0, num_frames, batch_size):
128
  end_idx = min(i + batch_size, num_frames)
 
 
 
129
  batch_frames = render_utils.render_video(
130
+ outputs['gaussian'][0],
131
+ num_frames=end_idx - i,
132
  start_frame=i
133
  )['color']
134
+ video_frames.extend(batch_frames)
135
 
 
136
  batch_geo = render_utils.render_video(
137
+ outputs['mesh'][0],
138
+ num_frames=end_idx - i,
139
  start_frame=i
140
  )['normal']
141
+ video_geo_frames.extend(batch_geo)
142
 
143
  # Clear cache after each batch
144
  torch.cuda.empty_cache()
145
+
146
+ # Combine frames and save video
147
+ video = [np.concatenate([video_frames[i], video_geo_frames[i]], axis=1)
148
+ for i in range(len(video_frames))]
149
  trial_id = str(uuid.uuid4())
150
  video_path = os.path.join(user_dir, f"{trial_id}.mp4")
151
  imageio.mimsave(video_path, video, fps=15)
152
 
153
+ # Clear video data
154
+ del video_frames
155
+ del video_geo_frames
156
  del video
 
157
  torch.cuda.empty_cache()
158
 
159
+ # Pack state
160
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
161
  return state, video_path
162
 
163
  @spaces.GPU
164
+ def extract_high_quality_glb(
165
  state: dict,
166
  req: gr.Request,
167
  ) -> Tuple[str, str]:
168
  """
169
+ Extract a high-quality GLB file with memory optimization.
170
  """
171
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
172
  gs, mesh, trial_id = unpack_state(state)
173
 
174
+ # Clear cache before GLB generation
175
  torch.cuda.empty_cache()
176
 
177
+ # Process mesh in original quality (no reduction)
178
  glb = postprocessing_utils.to_glb(
179
  gs,
180
  mesh,
 
184
  texture_size=2048, # Maximum texture resolution
185
  verbose=True # Show progress
186
  )
187
+
188
  glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
189
  glb.export(glb_path)
190
 
191
+ # Final cleanup
192
  torch.cuda.empty_cache()
193
  return glb_path, glb_path
194
 
195
  @spaces.GPU
196
+ def extract_reduced_glb(
197
  state: dict,
198
  mesh_simplify: float,
199
  texture_size: int,
200
  req: gr.Request,
201
  ) -> Tuple[str, str]:
202
  """
203
+ Extract a reduced-quality GLB file with memory optimization.
204
  """
205
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
206
  gs, mesh, trial_id = unpack_state(state)
207
+
208
+ # Clear cache before GLB generation
209
+ torch.cuda.empty_cache()
210
+
211
+ glb = postprocessing_utils.to_glb(
212
+ gs,
213
+ mesh,
214
+ simplify=mesh_simplify,
215
+ texture_size=texture_size,
216
+ verbose=True
217
+ )
218
+ glb_path = os.path.join(user_dir, f"{trial_id}_reduced.glb")
219
  glb.export(glb_path)
220
+
221
+ # Final cleanup
222
+ torch.cuda.empty_cache()
223
  return glb_path, glb_path
224
 
225
  with gr.Blocks(delete_cache=(600, 600)) as demo:
226
  gr.Markdown("""
227
  ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
228
+ * Upload an image and click "Generate" to create a 3D asset
229
  * After generation:
230
+ * Click "Extract Full GLB" for maximum quality (no reduction)
231
+ * Or use GLB Extraction Settings for a reduced version
232
  """)
233
 
234
  with gr.Row():
 
248
  slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
249
 
250
  generate_btn = gr.Button("Generate")
251
+ extract_full_btn = gr.Button("Extract Full GLB", interactive=False)
252
 
253
  with gr.Accordion(label="GLB Extraction Settings", open=False):
254
  mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
255
  texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
256
 
257
+ extract_reduced_btn = gr.Button("Extract Reduced GLB", interactive=False)
258
 
259
  with gr.Column():
260
  video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
 
265
 
266
  output_buf = gr.State()
267
 
268
+ # Example images
269
  with gr.Row():
270
  examples = gr.Examples(
271
  examples=[
 
298
  inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
299
  outputs=[output_buf, video_output],
300
  ).then(
301
+ lambda: [gr.Button(interactive=True), gr.Button(interactive=True)],
302
+ outputs=[extract_full_btn, extract_reduced_btn],
303
  )
304
 
305
+ extract_full_btn.click(
306
+ extract_high_quality_glb,
307
  inputs=[output_buf],
308
  outputs=[model_output, download_full],
309
  ).then(
 
311
  outputs=[download_full],
312
  )
313
 
314
+ extract_reduced_btn.click(
315
+ extract_reduced_glb,
316
  inputs=[output_buf, mesh_simplify, texture_size],
317
  outputs=[model_output, download_reduced],
318
  ).then(