Ryukijano commited on
Commit
53f1f2c
·
verified ·
1 Parent(s): acebad3

Update app.py

Browse files

Pseudocode
Add sliders for resolution and num_gauss to the Gradio interface.
Modify the preprocess function to accept resolution as a parameter.
Modify the reconstruct_and_export function to accept num_gauss as a parameter.
Update the Gradio interface to include the new sliders and pass their values to the respective functions

Files changed (1) hide show
  1. app.py +12 -44
app.py CHANGED
@@ -1,18 +1,3 @@
1
- import sys
2
- import spaces
3
- sys.path.append("flash3d") # Add the flash3d directory to the system path for importing local modules
4
-
5
- from omegaconf import OmegaConf
6
- import gradio as gr
7
- import torch
8
- import torchvision.transforms as TT
9
- import torchvision.transforms.functional as TTF
10
- from huggingface_hub import hf_hub_download
11
- import numpy as np
12
-
13
- from networks.gaussian_predictor import GaussianPredictor
14
- from util.vis3d import save_ply
15
-
16
  def main():
17
  print("[INFO] Starting main function...")
18
  if torch.cuda.is_available():
@@ -32,12 +17,8 @@ def main():
32
 
33
  print("[INFO] Initializing GaussianPredictor model...")
34
  model = GaussianPredictor(cfg)
35
- try:
36
- device = torch.device(device)
37
- model.to(device)
38
- except Exception as e:
39
- print(f"[ERROR] Failed to set device: {e}")
40
- raise
41
 
42
  print("[INFO] Loading model weights...")
43
  model.load_model(model_path)
@@ -52,31 +33,23 @@ def main():
52
  raise gr.Error("No image uploaded!")
53
  print("[INFO] Input image is valid.")
54
 
55
- def preprocess(image, padding_value):
56
  print("[DEBUG] Preprocessing image...")
57
- image = TTF.resize(image, (cfg.dataset.height, cfg.dataset.width), interpolation=TT.InterpolationMode.BICUBIC)
58
- pad_border_fn = TT.Pad((padding_value, padding_value))
59
  image = pad_border_fn(image)
60
  print("[INFO] Image preprocessing complete.")
61
  return image
62
 
63
  @spaces.GPU(duration=120)
64
- def reconstruct_and_export(image, num_gauss, max_sh_degree, scaling_modifier):
65
  print("[DEBUG] Starting reconstruction and export...")
66
  image = to_tensor(image).to(device).unsqueeze(0)
67
  inputs = {("color_aug", 0, 0): image}
68
-
69
  print("[INFO] Passing image through the model...")
70
  outputs = model(inputs)
71
-
72
- gauss_means = outputs[('gauss_means',0, 0)]
73
- if gauss_means.shape[0] % num_gauss != 0:
74
- raise ValueError(f"Shape mismatch: cannot divide axis of length {gauss_means.shape[0]} into chunks of {num_gauss}")
75
-
76
  print(f"[INFO] Saving output to {ply_out_path}...")
77
- save_ply(outputs, ply_out_path, num_gauss=num_gauss, max_sh_degree=max_sh_degree, scaling_modifier=scaling_modifier)
78
  print("[INFO] Reconstruction and export complete.")
79
-
80
  return ply_out_path
81
 
82
  ply_out_path = f'./mesh.ply'
@@ -94,15 +67,9 @@ def main():
94
  with gr.Column(scale=1):
95
  with gr.Row():
96
  input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
97
- with gr.Row():
98
- num_gauss = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Gaussians per Pixel", value=10)
99
- padding_value = gr.Slider(minimum=0, maximum=128, step=8, label="Padding Amount for Output Processing", value=32)
100
- max_sh_degree = gr.Slider(minimum=1, maximum=10, step=1, label="Max SH Degree", value=1)
101
- scaling_modifier = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="Scaling Modifier", value=1.0)
102
  with gr.Row():
103
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
104
-
105
- with gr.Row(variant="panel"):
106
  gr.Examples(
107
  examples=[
108
  './demo_examples/bedroom_01.png',
@@ -117,22 +84,23 @@ def main():
117
  label="Examples",
118
  examples_per_page=20,
119
  )
120
-
121
  with gr.Row():
122
  processed_image = gr.Image(label="Processed Image", interactive=False)
123
-
124
  with gr.Column(scale=2):
125
  with gr.Row():
126
  with gr.Tab("Reconstruction"):
127
  output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
 
 
 
128
 
129
  submit.click(fn=check_input_image, inputs=[input_image]).success(
130
  fn=preprocess,
131
- inputs=[input_image, padding_value],
132
  outputs=[processed_image],
133
  ).success(
134
  fn=reconstruct_and_export,
135
- inputs=[processed_image, num_gauss, max_sh_degree, scaling_modifier],
136
  outputs=[output_model],
137
  )
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def main():
2
  print("[INFO] Starting main function...")
3
  if torch.cuda.is_available():
 
17
 
18
  print("[INFO] Initializing GaussianPredictor model...")
19
  model = GaussianPredictor(cfg)
20
+ device = torch.device(device)
21
+ model.to(device)
 
 
 
 
22
 
23
  print("[INFO] Loading model weights...")
24
  model.load_model(model_path)
 
33
  raise gr.Error("No image uploaded!")
34
  print("[INFO] Input image is valid.")
35
 
36
+ def preprocess(image, resolution):
37
  print("[DEBUG] Preprocessing image...")
38
+ image = TTF.resize(image, (resolution, resolution), interpolation=TT.InterpolationMode.BICUBIC)
 
39
  image = pad_border_fn(image)
40
  print("[INFO] Image preprocessing complete.")
41
  return image
42
 
43
  @spaces.GPU(duration=120)
44
+ def reconstruct_and_export(image, num_gauss):
45
  print("[DEBUG] Starting reconstruction and export...")
46
  image = to_tensor(image).to(device).unsqueeze(0)
47
  inputs = {("color_aug", 0, 0): image}
 
48
  print("[INFO] Passing image through the model...")
49
  outputs = model(inputs)
 
 
 
 
 
50
  print(f"[INFO] Saving output to {ply_out_path}...")
51
+ save_ply(outputs, ply_out_path, num_gauss=num_gauss)
52
  print("[INFO] Reconstruction and export complete.")
 
53
  return ply_out_path
54
 
55
  ply_out_path = f'./mesh.ply'
 
67
  with gr.Column(scale=1):
68
  with gr.Row():
69
  input_image = gr.Image(label="Input Image", image_mode="RGBA", sources="upload", type="pil", elem_id="content_image")
 
 
 
 
 
70
  with gr.Row():
71
  submit = gr.Button("Generate", elem_id="generate", variant="primary")
72
+ with gr.Row(variant="panel"):
 
73
  gr.Examples(
74
  examples=[
75
  './demo_examples/bedroom_01.png',
 
84
  label="Examples",
85
  examples_per_page=20,
86
  )
 
87
  with gr.Row():
88
  processed_image = gr.Image(label="Processed Image", interactive=False)
 
89
  with gr.Column(scale=2):
90
  with gr.Row():
91
  with gr.Tab("Reconstruction"):
92
  output_model = gr.Model3D(height=512, label="Output Model", interactive=False)
93
+ with gr.Row():
94
+ resolution = gr.Slider(minimum=256, maximum=1024, step=64, label="Image Resolution", value=cfg.dataset.height)
95
+ num_gauss = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Gaussian Components", value=2)
96
 
97
  submit.click(fn=check_input_image, inputs=[input_image]).success(
98
  fn=preprocess,
99
+ inputs=[input_image, resolution],
100
  outputs=[processed_image],
101
  ).success(
102
  fn=reconstruct_and_export,
103
+ inputs=[processed_image, num_gauss],
104
  outputs=[output_model],
105
  )
106