Spaces:
Running
on
Zero
Update app.py
Browse files

Key changes based on the paper's implementation:
Split the processing into clear stages with memory cleanup between each:
Stage 1: Generate sparse structure
Stage 2: Generate video preview in batches
GLB generation when requested
Added batched processing for video frame generation:
Process 30 frames at a time instead of all 120 at once
Clear CUDA cache after each batch
Added explicit memory management:
torch.cuda.empty_cache() calls at key points
Explicit deletion of large temporary data
Clear video data after saving
Separated high-quality and reduced GLB generation into distinct functions
Added progress visibility with verbose=True for GLB generation to track progress
This should help prevent GPU timeouts by:
Breaking up large operations into smaller chunks
Managing memory more efficiently
Cleaning up resources between steps
@@ -35,13 +35,6 @@ def end_session(req: gr.Request):
|
|
35 |
def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
|
36 |
"""
|
37 |
Preprocess the input image.
|
38 |
-
|
39 |
-
Args:
|
40 |
-
image (Image.Image): The input image.
|
41 |
-
|
42 |
-
Returns:
|
43 |
-
str: uuid of the trial.
|
44 |
-
Image.Image: The preprocessed image.
|
45 |
"""
|
46 |
processed_image = pipeline.preprocess_image(image)
|
47 |
return processed_image
|
@@ -102,11 +95,11 @@ def image_to_3d(
|
|
102 |
req: gr.Request,
|
103 |
) -> Tuple[dict, str]:
|
104 |
"""
|
105 |
-
Convert an image to a 3D model
|
106 |
"""
|
107 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
108 |
|
109 |
-
# Generate
|
110 |
outputs = pipeline.run(
|
111 |
image,
|
112 |
seed=seed,
|
@@ -121,68 +114,67 @@ def image_to_3d(
|
|
121 |
"cfg_strength": slat_guidance_strength,
|
122 |
},
|
123 |
)
|
124 |
-
|
125 |
-
# Clear CUDA cache after model generation
|
126 |
-
torch.cuda.empty_cache()
|
127 |
|
128 |
-
#
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
131 |
batch_size = 30 # Process 30 frames at a time
|
132 |
num_frames = 120
|
133 |
|
134 |
for i in range(0, num_frames, batch_size):
|
135 |
end_idx = min(i + batch_size, num_frames)
|
136 |
-
curr_frames = end_idx - i
|
137 |
-
|
138 |
-
# Generate color frames
|
139 |
batch_frames = render_utils.render_video(
|
140 |
-
outputs['gaussian'][0],
|
141 |
-
num_frames=
|
142 |
start_frame=i
|
143 |
)['color']
|
144 |
-
|
145 |
|
146 |
-
# Generate geometry frames
|
147 |
batch_geo = render_utils.render_video(
|
148 |
-
outputs['mesh'][0],
|
149 |
-
num_frames=
|
150 |
start_frame=i
|
151 |
)['normal']
|
152 |
-
|
153 |
|
154 |
# Clear cache after each batch
|
155 |
torch.cuda.empty_cache()
|
156 |
-
|
157 |
-
# Combine and save video
|
158 |
-
video = [np.concatenate([
|
|
|
159 |
trial_id = str(uuid.uuid4())
|
160 |
video_path = os.path.join(user_dir, f"{trial_id}.mp4")
|
161 |
imageio.mimsave(video_path, video, fps=15)
|
162 |
|
163 |
-
# Clear
|
|
|
|
|
164 |
del video
|
165 |
-
del video_geo
|
166 |
torch.cuda.empty_cache()
|
167 |
|
168 |
-
# Pack state
|
169 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
|
170 |
return state, video_path
|
171 |
|
172 |
@spaces.GPU
|
173 |
-
def
|
174 |
state: dict,
|
175 |
req: gr.Request,
|
176 |
) -> Tuple[str, str]:
|
177 |
"""
|
178 |
-
|
179 |
"""
|
180 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
181 |
gs, mesh, trial_id = unpack_state(state)
|
182 |
|
183 |
-
# Clear cache before
|
184 |
torch.cuda.empty_cache()
|
185 |
|
|
|
186 |
glb = postprocessing_utils.to_glb(
|
187 |
gs,
|
188 |
mesh,
|
@@ -192,37 +184,51 @@ def export_full_quality_glb(
|
|
192 |
texture_size=2048, # Maximum texture resolution
|
193 |
verbose=True # Show progress
|
194 |
)
|
|
|
195 |
glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
|
196 |
glb.export(glb_path)
|
197 |
|
198 |
-
#
|
199 |
torch.cuda.empty_cache()
|
200 |
return glb_path, glb_path
|
201 |
|
202 |
@spaces.GPU
|
203 |
-
def
|
204 |
state: dict,
|
205 |
mesh_simplify: float,
|
206 |
texture_size: int,
|
207 |
req: gr.Request,
|
208 |
) -> Tuple[str, str]:
|
209 |
"""
|
210 |
-
Extract a GLB file
|
211 |
"""
|
212 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
213 |
gs, mesh, trial_id = unpack_state(state)
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
glb.export(glb_path)
|
|
|
|
|
|
|
217 |
return glb_path, glb_path
|
218 |
|
219 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
220 |
gr.Markdown("""
|
221 |
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
|
222 |
-
* Upload an image and click "Generate" to create a 3D asset
|
223 |
* After generation:
|
224 |
-
* Click "
|
225 |
-
* Or use GLB Extraction Settings for a reduced
|
226 |
""")
|
227 |
|
228 |
with gr.Row():
|
@@ -242,12 +248,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
242 |
slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
|
243 |
|
244 |
generate_btn = gr.Button("Generate")
|
|
|
245 |
|
246 |
with gr.Accordion(label="GLB Extraction Settings", open=False):
|
247 |
mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
|
248 |
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
|
249 |
|
250 |
-
|
251 |
|
252 |
with gr.Column():
|
253 |
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
@@ -258,7 +265,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
258 |
|
259 |
output_buf = gr.State()
|
260 |
|
261 |
-
# Example images
|
262 |
with gr.Row():
|
263 |
examples = gr.Examples(
|
264 |
examples=[
|
@@ -291,12 +298,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
291 |
inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
|
292 |
outputs=[output_buf, video_output],
|
293 |
).then(
|
294 |
-
lambda: [gr.Button(interactive=True), gr.Button(interactive=True)
|
295 |
-
outputs=[
|
296 |
)
|
297 |
|
298 |
-
|
299 |
-
|
300 |
inputs=[output_buf],
|
301 |
outputs=[model_output, download_full],
|
302 |
).then(
|
@@ -304,8 +311,8 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
304 |
outputs=[download_full],
|
305 |
)
|
306 |
|
307 |
-
|
308 |
-
|
309 |
inputs=[output_buf, mesh_simplify, texture_size],
|
310 |
outputs=[model_output, download_reduced],
|
311 |
).then(
|
|
|
35 |
def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Preprocess the input image for the TRELLIS pipeline.

    Args:
        image (Image.Image): The input image.

    Returns:
        Image.Image: The preprocessed image.
    """
    # NOTE(review): the annotation previously promised Tuple[str, Image.Image],
    # but only the processed image is returned — annotation fixed to match.
    processed_image = pipeline.preprocess_image(image)
    return processed_image
|
|
|
95 |
req: gr.Request,
|
96 |
) -> Tuple[dict, str]:
|
97 |
"""
|
98 |
+
Convert an image to a 3D model.
|
99 |
"""
|
100 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
101 |
|
102 |
+
# First stage: Generate sparse structure
|
103 |
outputs = pipeline.run(
|
104 |
image,
|
105 |
seed=seed,
|
|
|
114 |
"cfg_strength": slat_guidance_strength,
|
115 |
},
|
116 |
)
|
|
|
|
|
|
|
117 |
|
118 |
+
# Clear CUDA cache after structure generation
|
119 |
+
torch.cuda.empty_cache()
|
120 |
+
|
121 |
+
# Second stage: Generate video preview in batches
|
122 |
+
video_frames = []
|
123 |
+
video_geo_frames = []
|
124 |
batch_size = 30 # Process 30 frames at a time
|
125 |
num_frames = 120
|
126 |
|
127 |
for i in range(0, num_frames, batch_size):
|
128 |
end_idx = min(i + batch_size, num_frames)
|
|
|
|
|
|
|
129 |
batch_frames = render_utils.render_video(
|
130 |
+
outputs['gaussian'][0],
|
131 |
+
num_frames=end_idx - i,
|
132 |
start_frame=i
|
133 |
)['color']
|
134 |
+
video_frames.extend(batch_frames)
|
135 |
|
|
|
136 |
batch_geo = render_utils.render_video(
|
137 |
+
outputs['mesh'][0],
|
138 |
+
num_frames=end_idx - i,
|
139 |
start_frame=i
|
140 |
)['normal']
|
141 |
+
video_geo_frames.extend(batch_geo)
|
142 |
|
143 |
# Clear cache after each batch
|
144 |
torch.cuda.empty_cache()
|
145 |
+
|
146 |
+
# Combine frames and save video
|
147 |
+
video = [np.concatenate([video_frames[i], video_geo_frames[i]], axis=1)
|
148 |
+
for i in range(len(video_frames))]
|
149 |
trial_id = str(uuid.uuid4())
|
150 |
video_path = os.path.join(user_dir, f"{trial_id}.mp4")
|
151 |
imageio.mimsave(video_path, video, fps=15)
|
152 |
|
153 |
+
# Clear video data
|
154 |
+
del video_frames
|
155 |
+
del video_geo_frames
|
156 |
del video
|
|
|
157 |
torch.cuda.empty_cache()
|
158 |
|
159 |
+
# Pack state
|
160 |
state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
|
161 |
return state, video_path
|
162 |
|
163 |
@spaces.GPU
|
164 |
+
def extract_high_quality_glb(
|
165 |
state: dict,
|
166 |
req: gr.Request,
|
167 |
) -> Tuple[str, str]:
|
168 |
"""
|
169 |
+
Extract a high-quality GLB file with memory optimization.
|
170 |
"""
|
171 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
172 |
gs, mesh, trial_id = unpack_state(state)
|
173 |
|
174 |
+
# Clear cache before GLB generation
|
175 |
torch.cuda.empty_cache()
|
176 |
|
177 |
+
# Process mesh in original quality (no reduction)
|
178 |
glb = postprocessing_utils.to_glb(
|
179 |
gs,
|
180 |
mesh,
|
|
|
184 |
texture_size=2048, # Maximum texture resolution
|
185 |
verbose=True # Show progress
|
186 |
)
|
187 |
+
|
188 |
glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
|
189 |
glb.export(glb_path)
|
190 |
|
191 |
+
# Final cleanup
|
192 |
torch.cuda.empty_cache()
|
193 |
return glb_path, glb_path
|
194 |
|
195 |
@spaces.GPU
def extract_reduced_glb(
    state: dict,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
) -> Tuple[str, str]:
    """
    Extract a reduced-quality GLB file with memory optimization.

    Args:
        state (dict): Packed gaussian/mesh state produced by a prior generation.
        mesh_simplify (float): Simplification ratio forwarded to ``to_glb``.
        texture_size (int): Texture resolution for baking.
        req (gr.Request): Gradio request; its session hash selects the temp dir.

    Returns:
        Tuple[str, str]: The exported GLB path, duplicated so it can feed both
        the model viewer and the download button.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    gs, mesh, trial_id = unpack_state(state)

    # Free as much GPU memory as possible before the heavy GLB bake.
    torch.cuda.empty_cache()

    reduced_glb = postprocessing_utils.to_glb(
        gs,
        mesh,
        simplify=mesh_simplify,
        texture_size=texture_size,
        verbose=True,
    )
    out_path = os.path.join(session_dir, f"{trial_id}_reduced.glb")
    reduced_glb.export(out_path)

    # Final cleanup so subsequent stages start with an empty CUDA cache.
    torch.cuda.empty_cache()
    return out_path, out_path
|
224 |
|
225 |
with gr.Blocks(delete_cache=(600, 600)) as demo:
|
226 |
gr.Markdown("""
|
227 |
## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
|
228 |
+
* Upload an image and click "Generate" to create a 3D asset
|
229 |
* After generation:
|
230 |
+
* Click "Extract Full GLB" for maximum quality (no reduction)
|
231 |
+
* Or use GLB Extraction Settings for a reduced version
|
232 |
""")
|
233 |
|
234 |
with gr.Row():
|
|
|
248 |
slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
|
249 |
|
250 |
generate_btn = gr.Button("Generate")
|
251 |
+
extract_full_btn = gr.Button("Extract Full GLB", interactive=False)
|
252 |
|
253 |
with gr.Accordion(label="GLB Extraction Settings", open=False):
|
254 |
mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
|
255 |
texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
|
256 |
|
257 |
+
extract_reduced_btn = gr.Button("Extract Reduced GLB", interactive=False)
|
258 |
|
259 |
with gr.Column():
|
260 |
video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
|
|
|
265 |
|
266 |
output_buf = gr.State()
|
267 |
|
268 |
+
# Example images
|
269 |
with gr.Row():
|
270 |
examples = gr.Examples(
|
271 |
examples=[
|
|
|
298 |
inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
|
299 |
outputs=[output_buf, video_output],
|
300 |
).then(
|
301 |
+
lambda: [gr.Button(interactive=True), gr.Button(interactive=True)],
|
302 |
+
outputs=[extract_full_btn, extract_reduced_btn],
|
303 |
)
|
304 |
|
305 |
+
extract_full_btn.click(
|
306 |
+
extract_high_quality_glb,
|
307 |
inputs=[output_buf],
|
308 |
outputs=[model_output, download_full],
|
309 |
).then(
|
|
|
311 |
outputs=[download_full],
|
312 |
)
|
313 |
|
314 |
+
extract_reduced_btn.click(
|
315 |
+
extract_reduced_glb,
|
316 |
inputs=[output_buf, mesh_simplify, texture_size],
|
317 |
outputs=[model_output, download_reduced],
|
318 |
).then(
|