Ryukijano committed · verified
Commit b789e6e · 1 Parent(s): a4b32a9

Set GPU duration to maximum allowed value for ZeroGPU spaces


- Updated the `@spaces.GPU` decorator to allocate the GPU for 120 seconds, the maximum duration ZeroGPU Spaces currently allow for a single function execution.
- Requesting more than this cap can fail or leave GPU requests unfulfilled, so the allocation is pinned to the limit.

This keeps GPU usage within Hugging Face ZeroGPU's allocation constraints while giving long-running reconstructions as much time as the platform permits.
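For reference, the allocation pattern the commit describes looks like the minimal sketch below. `spaces.GPU(duration=...)` is the actual ZeroGPU decorator; the `run_reconstruction` function and its body are illustrative placeholders, not code from this repo:

```python
import spaces
import torch

@spaces.GPU(duration=120)  # request a ZeroGPU slice for up to 120 s per call
def run_reconstruction(image_tensor: torch.Tensor) -> torch.Tensor:
    # A GPU is attached only while a @spaces.GPU-decorated function runs,
    # so CUDA work belongs inside the decorated function.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return image_tensor.to(device) * 2  # stand-in for the real model forward pass
```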

Files changed (1)
  1. app.py  +273 -8
app.py CHANGED
@@ -1,6 +1,6 @@
 import sys
 import spaces
-sys.path.append("flash3d")

 from omegaconf import OmegaConf
 import gradio as gr
@@ -8,61 +8,295 @@ import torch
 import torchvision.transforms as TT
 import torchvision.transforms.functional as TTF
 from huggingface_hub import hf_hub_download

 from networks.gaussian_predictor import GaussianPredictor
 from util.vis3d import save_ply

 def main():
     if torch.cuda.is_available():
         device = "cuda:0"
     else:
         device = "cpu"

     model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                      filename="config_re10k_v1.yaml")
     model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                  filename="model_re10k_v1.pth")

     cfg = OmegaConf.load(model_cfg_path)

     model = GaussianPredictor(cfg)
     device = torch.device(device)
-    model.to(device)
     model.load_model(model_path)

-    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
-    to_tensor = TT.ToTensor()

     def check_input_image(input_image):
         if input_image is None:
             raise gr.Error("No image uploaded!")

     def preprocess(image):
         image = TTF.resize(
             image, (cfg.dataset.height, cfg.dataset.width),
             interpolation=TT.InterpolationMode.BICUBIC
         )
         image = pad_border_fn(image)
         return image

-    @spaces.GPU(duration=120)
     def reconstruct_and_export(image):
         """
         Passes image through model, outputs reconstruction in form of a dict of tensors.
         """
         image = to_tensor(image).to(device).unsqueeze(0)
         inputs = {
             ("color_aug", 0, 0): image,
         }

         outputs = model(inputs)

-        # export reconstruction to ply
         save_ply(outputs, ply_out_path, num_gauss=2)

         return ply_out_path

     ply_out_path = f'./mesh.ply'

     css = """
     h1 {
         text-align: center;
@@ -70,15 +304,38 @@ def main():
     }
     """

     with gr.Blocks(css=css) as demo:
         gr.Markdown(
             """
             # Flash3D
             """
-        )
         with gr.Row(variant="panel"):
             with gr.Column(scale=1):
                 with gr.Row():
                     input_image = gr.Image(
                         label="Input Image",
                         image_mode="RGBA",
@@ -87,9 +344,11 @@
                         elem_id="content_image",
                     )
                 with gr.Row():
                     submit = gr.Button("Generate", elem_id="generate", variant="primary")

             with gr.Row(variant="panel"):
                 gr.Examples(
                     examples=[
                         './demo_examples/bedroom_01.png',
@@ -106,17 +365,20 @@
                 )

             with gr.Row():
                 processed_image = gr.Image(label="Processed Image", interactive=False)

             with gr.Column(scale=2):
                 with gr.Row():
                     with gr.Tab("Reconstruction"):
                         output_model = gr.Model3D(
                             height=512,
                             label="Output Model",
                             interactive=False
                         )

         submit.click(fn=check_input_image, inputs=[input_image]).success(
             fn=preprocess,
             inputs=[input_image],
@@ -127,8 +389,11 @@
             outputs=[output_model],
         )

     demo.queue(max_size=1)
-    demo.launch(share=True)

 if __name__ == "__main__":
     main()
 
app.py (resulting file)
 import sys
 import spaces
+sys.path.append("flash3d")  # Add the flash3d directory to the system path for importing local modules

 from omegaconf import OmegaConf
 import gradio as gr
 import torch
 import torchvision.transforms as TT
 import torchvision.transforms.functional as TTF
 from huggingface_hub import hf_hub_download
+import numpy as np

 from networks.gaussian_predictor import GaussianPredictor
 from util.vis3d import save_ply

 def main():
+    print("[INFO] Starting main function...")
+    # Determine if CUDA (GPU) is available and set the device accordingly
     if torch.cuda.is_available():
         device = "cuda:0"
+        print("[INFO] CUDA is available. Using GPU device.")
     else:
         device = "cpu"
+        print("[INFO] CUDA is not available. Using CPU device.")

+    # Download model configuration and weights from Hugging Face Hub
+    print("[INFO] Downloading model configuration...")
     model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                      filename="config_re10k_v1.yaml")
+    print("[INFO] Downloading model weights...")
     model_path = hf_hub_download(repo_id="einsafutdinov/flash3d",
                                  filename="model_re10k_v1.pth")

+    # Load model configuration using OmegaConf
+    print("[INFO] Loading model configuration...")
     cfg = OmegaConf.load(model_cfg_path)
+
+    # Initialize the GaussianPredictor model with the loaded configuration
+    print("[INFO] Initializing GaussianPredictor model...")
     model = GaussianPredictor(cfg)
     device = torch.device(device)
+    model.to(device)  # Move the model to the specified device (CPU or GPU)
+
+    # Load the pre-trained model weights
+    print("[INFO] Loading model weights...")
     model.load_model(model_path)

+    # Define transformation functions for image preprocessing
+    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))  # Padding to augment the image borders
+    to_tensor = TT.ToTensor()  # Convert image to tensor

+    # Function to check if an image is uploaded by the user
     def check_input_image(input_image):
+        print("[DEBUG] Checking input image...")
         if input_image is None:
+            print("[ERROR] No image uploaded!")
             raise gr.Error("No image uploaded!")
+        print("[INFO] Input image is valid.")

+    # Function to preprocess the input image before passing it to the model
     def preprocess(image):
+        print("[DEBUG] Preprocessing image...")
+        # Resize the image to the desired height and width specified in the configuration
         image = TTF.resize(
             image, (cfg.dataset.height, cfg.dataset.width),
             interpolation=TT.InterpolationMode.BICUBIC
         )
+        # Apply padding to the image
         image = pad_border_fn(image)
+        print("[INFO] Image preprocessing complete.")
         return image

+    # Function to reconstruct the 3D model from the input image and export it as a PLY file
+    @spaces.GPU(duration=120)  # Decorator to allocate a GPU for this function during execution
     def reconstruct_and_export(image):
         """
         Passes image through model, outputs reconstruction in form of a dict of tensors.
         """
+        print("[DEBUG] Starting reconstruction and export...")
+        # Convert the preprocessed image to a tensor and move it to the specified device
         image = to_tensor(image).to(device).unsqueeze(0)
         inputs = {
             ("color_aug", 0, 0): image,
         }

+        # Pass the image through the model to get the output
+        print("[INFO] Passing image through the model...")
         outputs = model(inputs)

+        # Export the reconstruction to a PLY file
+        print(f"[INFO] Saving output to {ply_out_path}...")
         save_ply(outputs, ply_out_path, num_gauss=2)
+        print("[INFO] Reconstruction and export complete.")

         return ply_out_path

+    # Path to save the output PLY file
     ply_out_path = f'./mesh.ply'

+    # CSS styling for the Gradio interface
     css = """
     h1 {
         text-align: center;
         display:block;
     }
     """

+    # Create the Gradio user interface
     with gr.Blocks(css=css) as demo:
         gr.Markdown(
             """
             # Flash3D
             """
+        )
+        # Comments about the app's behavior and known limitations
+        gr.Markdown(
+            """
+            ## Comments:
+            1. If you run the demo online, the first example you upload should take about 4.5 seconds (with preprocessing, saving and overhead); the following ones take about 1.5 s.
+            2. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. This is only an approximation and artifacts might show.
+            3. Known limitations include:
+               - A black dot appearing on the model from some viewpoints.
+               - See-through parts of objects, especially on the back: this is due to the model performing less well on more complicated shapes.
+               - Backs of objects are blurry: this is a model limitation due to it being deterministic.
+            4. Our model is of comparable quality to state-of-the-art methods, and is **much** cheaper to train and run.
+            ## How does it work?
+            Splatter Image formulates 3D reconstruction as an image-to-image translation task. It maps the input image to another image,
+            in which every pixel represents one 3D Gaussian and the channels of the output represent parameters of these Gaussians, including their shapes, colours, and locations.
+            The resulting image thus represents a set of Gaussians (almost like a point cloud) which reconstruct the shape and colour of the object.
+            The method is very cheap: the reconstruction amounts to a single forward pass of a neural network with only 2D operators (2D convolutions and attention).
+            The rendering is also very fast, due to using Gaussian Splatting.
+            Combined, this results in very cheap training and high-quality results.
+            For more results see the [project page](https://szymanowiczs.github.io/splatter-image) and the [CVPR article](https://arxiv.org/abs/2312.13150).
+            """
+        )
         with gr.Row(variant="panel"):
             with gr.Column(scale=1):
                 with gr.Row():
+                    # Input image component for the user to upload an image
                     input_image = gr.Image(
                         label="Input Image",
                         image_mode="RGBA",
                         sources="upload",
                         type="pil",
                         elem_id="content_image",
                     )
                 with gr.Row():
+                    # Button to trigger the generation process
                     submit = gr.Button("Generate", elem_id="generate", variant="primary")

             with gr.Row(variant="panel"):
+                # Examples panel to provide sample images for users
                 gr.Examples(
                     examples=[
                         './demo_examples/bedroom_01.png',
                         './demo_examples/kitti_02.png',
                         './demo_examples/kitti_03.png',
                         './demo_examples/re10k_04.jpg',
                         './demo_examples/re10k_05.jpg',
                         './demo_examples/re10k_06.jpg',
                     ],
                     inputs=[input_image],
                     cache_examples=False,
                     label="Examples",
                     examples_per_page=20,
                 )

             with gr.Row():
+                # Display the preprocessed image (after resizing and padding)
                 processed_image = gr.Image(label="Processed Image", interactive=False)

             with gr.Column(scale=2):
                 with gr.Row():
                     with gr.Tab("Reconstruction"):
+                        # 3D model viewer to display the reconstructed model
                         output_model = gr.Model3D(
                             height=512,
                             label="Output Model",
                             interactive=False
                         )

+        # Define the workflow for the Generate button
         submit.click(fn=check_input_image, inputs=[input_image]).success(
             fn=preprocess,
             inputs=[input_image],
             outputs=[processed_image],
         ).success(
             fn=reconstruct_and_export,
             inputs=[processed_image],
             outputs=[output_model],
         )

+    # Queue the requests to handle them sequentially (to avoid GPU resource conflicts)
     demo.queue(max_size=1)
+    print("[INFO] Launching Gradio demo...")
+    demo.launch(share=True)  # Launch the Gradio interface and allow public sharing

 if __name__ == "__main__":
+    print("[INFO] Running application...")
     main()
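A note on the "How does it work?" text added above: the per-pixel Gaussian mapping it describes can be sketched as below. The tiny network, channel split, and tensor shapes are illustrative assumptions for intuition only, not the actual Flash3D/Splatter Image architecture:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the predictor: an image-to-image network whose
# output channels hold the parameters of one 3D Gaussian per pixel.
net = nn.Conv2d(3, 12, kernel_size=3, padding=1)

image = torch.randn(1, 3, 256, 384)   # (B, C, H, W) input image
params = net(image)                   # (B, 12, H, W): one Gaussian per pixel

# Assumed channel layout: opacity, depth, 3D scale, rotation quaternion, RGB colour.
opacity, depth, scale, rotation, colour = torch.split(params, [1, 1, 3, 4, 3], dim=1)
```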