Remove 3dgs for now
Files changed:
- app.py (+5 -4)
- demo/gs_train.py (+198 -201)
app.py
@@ -8,7 +8,7 @@ import gradio as gr
 from mast3r.demo import get_args_parser
 from mast3r.utils.misc import hash_md5
 from mast3r_demo import mast3r_demo_tab
-from gs_demo import gs_demo_tab
+# from gs_demo import gs_demo_tab
 
 if __name__ == '__main__':
     parser = get_args_parser()
@@ -20,6 +20,7 @@ if __name__ == '__main__':
     server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
 
     weights_path = args.weights if args.weights is not None else "naver/" + args.model_name
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
     chkpt_tag = hash_md5(weights_path)
 
     with tempfile.TemporaryDirectory(suffix='demo') as tmpdirname:
@@ -29,9 +30,9 @@
         with gr.Blocks() as demo:
             with gr.Tabs():
                 with gr.Tab("MASt3R Demo"):
-                    mast3r_demo_tab(cache_path, weights_path,
-                with gr.Tab("Gaussian Splatting Demo"):
-                    gs_demo_tab(cache_path)
+                    mast3r_demo_tab(cache_path, weights_path, device)
+                # with gr.Tab("Gaussian Splatting Demo"):
+                #     gs_demo_tab(cache_path)
 
         demo.launch(server_name=server_name, server_port=args.server_port)
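The change above disables the Gaussian Splatting tab by commenting it out. For comparison, a minimal sketch of gating a Gradio tab behind an environment flag instead, so it can be re-enabled without editing code (the ENABLE_GS variable and placeholder Markdown bodies are hypothetical, not part of this Space):

import os
import gradio as gr

# Hypothetical flag; this commit comments the tab out instead.
ENABLE_GS = os.environ.get("ENABLE_GS", "0") == "1"

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("MASt3R Demo"):
            gr.Markdown("MASt3R demo UI")
        if ENABLE_GS:
            # The tab is simply never built when the flag is off.
            with gr.Tab("Gaussian Splatting Demo"):
                gr.Markdown("Gaussian Splatting demo UI")

demo.launch()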
demo/gs_train.py
@@ -13,21 +13,19 @@ sys.path.append(gaussian_splatting_path)
 
 # Import necessary modules from the gaussian-splatting directory
 from utils.loss_utils import l1_loss, ssim
-# from gaussian_renderer import render, network_gui
 from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
-
-# from scene import Scene, GaussianModel
+from scene import Scene, GaussianModel
 from utils.general_utils import safe_state
 from utils.image_utils import psnr
 
-#
-
-
-
+# Dynamically import the train module from the gaussian-splatting directory
+train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
+gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
+train_spec.loader.exec_module(gaussian_splatting_train)
 
-#
-
-
+# Import the necessary functions from the dynamically loaded module
+prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
+training_report = gaussian_splatting_train.training_report
 
 from dataclasses import dataclass, field
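The hunk above pulls prepare_output_and_logger and training_report out of the vendored gaussian-splatting/train.py by loading that file directly, since the directory is appended to sys.path rather than installed as a package. A minimal, self-contained sketch of the same importlib pattern (the vendor/helpers.py path and the scale attribute are hypothetical names for illustration):

import importlib.util
import os

# Load a module from an explicit file path instead of an installed package.
module_path = os.path.join("vendor", "helpers.py")  # hypothetical file
spec = importlib.util.spec_from_file_location("vendored_helpers", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)  # executes helpers.py top-level code

scale = module.scale  # bind an attribute, as the diff does with training_report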
@@ -89,204 +87,203 @@ def train(
     densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
 ):
     print(data_source_path)
-    # … (≈200 removed lines of the previous train() body, not preserved in the rendered diff)
-    return None, None
+    # Create instances of the parameter dataclasses
+    dataset = ModelParams(
+        sh_degree=sh_degree,
+        source_path=data_source_path,
+        model_path=model_path,
+        images=images,
+        resolution=resolution,
+        white_background=white_background,
+        data_device=data_device,
+        eval=eval
+    )
+
+    pipe = PipelineParams(
+        convert_SHs_python=convert_SHs_python,
+        compute_cov3D_python=compute_cov3D_python,
+        debug=debug
+    )
+
+    opt = OptimizationParams(
+        iterations=iterations,
+        position_lr_init=position_lr_init,
+        position_lr_final=position_lr_final,
+        position_lr_delay_mult=position_lr_delay_mult,
+        position_lr_max_steps=position_lr_max_steps,
+        feature_lr=feature_lr,
+        opacity_lr=opacity_lr,
+        scaling_lr=scaling_lr,
+        rotation_lr=rotation_lr,
+        percent_dense=percent_dense,
+        lambda_dssim=lambda_dssim,
+        densification_interval=densification_interval,
+        opacity_reset_interval=opacity_reset_interval,
+        densify_from_iter=densify_from_iter,
+        densify_until_iter=densify_until_iter,
+        densify_grad_threshold=densify_grad_threshold,
+        random_background=random_background
+    )
+
+    args = TrainingArgs()
+
+    testing_iterations = args.test_iterations
+    saving_iterations = args.save_iterations
+    checkpoint_iterations = args.checkpoint_iterations
+    debug_from = args.debug_from
+
+    tb_writer = prepare_output_and_logger(dataset)
+
+    gaussians = GaussianModel(dataset.sh_degree)
+    scene = Scene(dataset, gaussians)
+    gaussians.training_setup(opt)
+
+    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+    iter_start = torch.cuda.Event(enable_timing=True)
+    iter_end = torch.cuda.Event(enable_timing=True)
+
+    viewpoint_stack = None
+    ema_loss_for_log = 0.0
+    first_iter = 0
+    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+    first_iter += 1
+
+    point_cloud_path = ""
+    progress = gr.Progress()  # Initialize the progress bar
+    for iteration in range(first_iter, opt.iterations + 1):
+        iter_start.record()
+        gaussians.update_learning_rate(iteration)
+
+        # Every 1000 its we increase the levels of SH up to a maximum degree
+        if iteration % 1000 == 0:
+            gaussians.oneupSHdegree()
+
+        # Pick a random Camera
+        if not viewpoint_stack:
+            viewpoint_stack = scene.getTrainCameras().copy()
+        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
+
+        # Render
+        if (iteration - 1) == debug_from:
+            pipe.debug = True
+        bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
+
+        # Loss
+        gt_image = viewpoint_cam.original_image.cuda()
+        Ll1 = l1_loss(image, gt_image)
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+        loss.backward()
+        iter_end.record()
+
+        with torch.no_grad():
+            # Progress bar
+            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+            if iteration % 10 == 0:
+                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+                progress_bar.update(10)
+                progress(iteration / opt.iterations)  # Update Gradio progress bar
+            if iteration == opt.iterations:
+                progress_bar.close()
+
+            # Log and save
+            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
+            if (iteration == opt.iterations):
+                point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
+                print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
+                scene.save(iteration)
+
+            # Densification
+            if iteration < opt.densify_until_iter:
+                # Keep track of max radii in image-space for pruning
+                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
+                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
+                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
+
+                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
+                    gaussians.reset_opacity()
+
+            # Optimizer step
+            if iteration < opt.iterations:
+                gaussians.optimizer.step()
+                gaussians.optimizer.zero_grad(set_to_none=True)
+
+            if (iteration == opt.iterations):
+                print("\n[ITER {}] Saving Checkpoint".format(iteration))
+                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
+
+
+    from os import makedirs
+    from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
+    import torchvision
+    import subprocess
+
+    @torch.no_grad()
+    def render_path(dataset: ModelParams, iteration: int, pipeline: PipelineParams, render_resize_method='crop'):
+        """
+        render_resize_method: crop, pad
+        """
+        gaussians = GaussianModel(dataset.sh_degree)
+        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
+
+        iteration = scene.loaded_iter
+
+        bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+        model_path = dataset.model_path
+        name = "render"
+
+        views = scene.getRenderCameras()
+
+        # print(len(views))
+        render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+
+        makedirs(render_path, exist_ok=True)
+
+        for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+            if render_resize_method == 'crop':
+                image_size = 256
+            elif render_resize_method == 'pad':
+                image_size = max(view.image_width, view.image_height)
+            else:
+                raise NotImplementedError
+            view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
+            focal_length_x = fov2focal(view.FoVx, view.image_width)
+            focal_length_y = fov2focal(view.FoVy, view.image_height)
+            view.image_width = image_size
+            view.image_height = image_size
+            view.FoVx = focal2fov(focal_length_x, image_size)
+            view.FoVy = focal2fov(focal_length_y, image_size)
+            view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0, 1).cuda().float()
+            view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
+
+            render_pkg = render(view, gaussians, pipeline, background)
+            rendering = render_pkg["render"]
+            torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
+
+        # Use ffmpeg to output video
+        renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
+        subprocess.run(["ffmpeg", "-y",
+                        "-framerate", "24",
+                        "-i", os.path.join(render_path, "%05d.png"),
+                        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
+                        "-c:v", "libx264",
+                        "-pix_fmt", "yuv420p",
+                        "-crf", "23",
+                        # "-pix_fmt", "yuv420p", # Set pixel format for compatibility
+                        renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+        return renders_path
+
+    renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
+
+    return renders_path, point_cloud_path
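For reference, the loss computed in the training loop above is the standard 3D Gaussian Splatting photometric objective, blending an L1 term with a D-SSIM term via lambda_dssim:

\[
\mathcal{L} = (1 - \lambda)\,\mathcal{L}_1(\hat{I}, I_{\mathrm{gt}}) + \lambda\,\bigl(1 - \mathrm{SSIM}(\hat{I}, I_{\mathrm{gt}})\bigr), \qquad \lambda = \texttt{lambda\_dssim}
\]

where \(\hat{I}\) is the rendered image and \(I_{\mathrm{gt}}\) the ground-truth view. The value shown in the progress bar is an exponential moving average of this loss (0.4 times the current value plus 0.6 times the previous average).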