ostapagon committed · Commit 61ba7d2 · 1 parent: ec1c193

Remove 3dgs for now

Files changed (2):
  1. app.py +5 -4
  2. demo/gs_train.py +198 -201
app.py CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
 from mast3r.demo import get_args_parser
 from mast3r.utils.misc import hash_md5
 from mast3r_demo import mast3r_demo_tab
-from gs_demo import gs_demo_tab
+# from gs_demo import gs_demo_tab
 
 if __name__ == '__main__':
     parser = get_args_parser()
@@ -20,6 +20,7 @@ if __name__ == '__main__':
     server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
 
     weights_path = args.weights if args.weights is not None else "naver/" + args.model_name
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
     chkpt_tag = hash_md5(weights_path)
 
     with tempfile.TemporaryDirectory(suffix='demo') as tmpdirname:
@@ -29,9 +30,9 @@ if __name__ == '__main__':
         with gr.Blocks() as demo:
             with gr.Tabs():
                 with gr.Tab("MASt3R Demo"):
-                    mast3r_demo_tab(cache_path, weights_path, args.device)
-                with gr.Tab("Gaussian Splatting Demo"):
-                    gs_demo_tab(cache_path)
+                    mast3r_demo_tab(cache_path, weights_path, device)
+                # with gr.Tab("Gaussian Splatting Demo"):
+                #     gs_demo_tab(cache_path)
 
             demo.launch(server_name=server_name, server_port=args.server_port)
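
For reference, the pattern this commit edits is Gradio's Blocks/Tabs layout: each demo tab is a function that builds its widgets inside a gr.Tab context, so dropping a tab is just a matter of commenting out its block. A minimal runnable sketch of the same wiring, with a hypothetical placeholder_tab standing in for mast3r_demo_tab:

import gradio as gr

def placeholder_tab():
    # Hypothetical stand-in for mast3r_demo_tab(cache_path, weights_path, device);
    # a real tab function builds its widgets inside the enclosing gr.Tab context.
    gr.Markdown("demo content goes here")

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("MASt3R Demo"):
            placeholder_tab()

demo.launch(server_name='127.0.0.1', server_port=7860)

Commenting a tab out, as the commit does with the Gaussian Splatting tab, simply removes it from the rendered tab bar; nothing else in the Blocks graph has to change.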
demo/gs_train.py CHANGED
@@ -13,21 +13,19 @@ sys.path.append(gaussian_splatting_path)
 
 # Import necessary modules from the gaussian-splatting directory
 from utils.loss_utils import l1_loss, ssim
-# from gaussian_renderer import render, network_gui
 from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
-
-# from scene import Scene, GaussianModel
+from scene import Scene, GaussianModel
 from utils.general_utils import safe_state
 from utils.image_utils import psnr
 
-# # Dynamically import the train module from the gaussian-splatting directory
-# train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
-# gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
-# train_spec.loader.exec_module(gaussian_splatting_train)
+# Dynamically import the train module from the gaussian-splatting directory
+train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
+gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
+train_spec.loader.exec_module(gaussian_splatting_train)
 
-# # Import the necessary functions from the dynamically loaded module
-# prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
-# training_report = gaussian_splatting_train.training_report
+# Import the necessary functions from the dynamically loaded module
+prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
+training_report = gaussian_splatting_train.training_report
 
 from dataclasses import dataclass, field
@@ -89,204 +87,203 @@ def train(
     densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
 ):
     print(data_source_path)
-    # # Create instances of the parameter dataclasses
-    # dataset = ModelParams(
-    #     sh_degree=sh_degree,
-    #     source_path=data_source_path,
-    #     model_path=model_path,
-    #     images=images,
-    #     resolution=resolution,
-    #     white_background=white_background,
-    #     data_device=data_device,
-    #     eval=eval
-    # )
+    # Create instances of the parameter dataclasses
+    dataset = ModelParams(
+        sh_degree=sh_degree,
+        source_path=data_source_path,
+        model_path=model_path,
+        images=images,
+        resolution=resolution,
+        white_background=white_background,
+        data_device=data_device,
+        eval=eval
+    )
 
-    # pipe = PipelineParams(
-    #     convert_SHs_python=convert_SHs_python,
-    #     compute_cov3D_python=compute_cov3D_python,
-    #     debug=debug
-    # )
+    pipe = PipelineParams(
+        convert_SHs_python=convert_SHs_python,
+        compute_cov3D_python=compute_cov3D_python,
+        debug=debug
+    )
 
-    # opt = OptimizationParams(
-    #     iterations=iterations,
-    #     position_lr_init=position_lr_init,
-    #     position_lr_final=position_lr_final,
-    #     position_lr_delay_mult=position_lr_delay_mult,
-    #     position_lr_max_steps=position_lr_max_steps,
-    #     feature_lr=feature_lr,
-    #     opacity_lr=opacity_lr,
-    #     scaling_lr=scaling_lr,
-    #     rotation_lr=rotation_lr,
-    #     percent_dense=percent_dense,
-    #     lambda_dssim=lambda_dssim,
-    #     densification_interval=densification_interval,
-    #     opacity_reset_interval=opacity_reset_interval,
-    #     densify_from_iter=densify_from_iter,
-    #     densify_until_iter=densify_until_iter,
-    #     densify_grad_threshold=densify_grad_threshold,
-    #     random_background=random_background
-    # )
-
-    # args = TrainingArgs()
-
-    # testing_iterations = args.test_iterations
-    # saving_iterations = args.save_iterations
-    # checkpoint_iterations = args.checkpoint_iterations
-    # debug_from = args.debug_from
-
-    # tb_writer = prepare_output_and_logger(dataset)
+    opt = OptimizationParams(
+        iterations=iterations,
+        position_lr_init=position_lr_init,
+        position_lr_final=position_lr_final,
+        position_lr_delay_mult=position_lr_delay_mult,
+        position_lr_max_steps=position_lr_max_steps,
+        feature_lr=feature_lr,
+        opacity_lr=opacity_lr,
+        scaling_lr=scaling_lr,
+        rotation_lr=rotation_lr,
+        percent_dense=percent_dense,
+        lambda_dssim=lambda_dssim,
+        densification_interval=densification_interval,
+        opacity_reset_interval=opacity_reset_interval,
+        densify_from_iter=densify_from_iter,
+        densify_until_iter=densify_until_iter,
+        densify_grad_threshold=densify_grad_threshold,
+        random_background=random_background
+    )
+
+    args = TrainingArgs()
+
+    testing_iterations = args.test_iterations
+    saving_iterations = args.save_iterations
+    checkpoint_iterations = args.checkpoint_iterations
+    debug_from = args.debug_from
+
+    tb_writer = prepare_output_and_logger(dataset)
 
-    # gaussians = GaussianModel(dataset.sh_degree)
-    # scene = Scene(dataset, gaussians)
-    # gaussians.training_setup(opt)
-
-    # bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
-    # background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
-
-    # iter_start = torch.cuda.Event(enable_timing = True)
-    # iter_end = torch.cuda.Event(enable_timing = True)
-
-    # viewpoint_stack = None
-    # ema_loss_for_log = 0.0
-    # first_iter = 0
-    # progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
-    # first_iter += 1
-
-    # point_cloud_path = ""
-    # progress = gr.Progress() # Initialize the progress bar
-    # for iteration in range(first_iter, opt.iterations + 1):
-    #     iter_start.record()
-    #     gaussians.update_learning_rate(iteration)
+    gaussians = GaussianModel(dataset.sh_degree)
+    scene = Scene(dataset, gaussians)
+    gaussians.training_setup(opt)
+
+    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+    iter_start = torch.cuda.Event(enable_timing = True)
+    iter_end = torch.cuda.Event(enable_timing = True)
+
+    viewpoint_stack = None
+    ema_loss_for_log = 0.0
+    first_iter = 0
+    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+    first_iter += 1
+
+    point_cloud_path = ""
+    progress = gr.Progress() # Initialize the progress bar
+    for iteration in range(first_iter, opt.iterations + 1):
+        iter_start.record()
+        gaussians.update_learning_rate(iteration)
 
-    #     # Every 1000 its we increase the levels of SH up to a maximum degree
-    #     if iteration % 1000 == 0:
-    #         gaussians.oneupSHdegree()
+        # Every 1000 its we increase the levels of SH up to a maximum degree
+        if iteration % 1000 == 0:
+            gaussians.oneupSHdegree()
 
-    #     # Pick a random Camera
-    #     if not viewpoint_stack:
-    #         viewpoint_stack = scene.getTrainCameras().copy()
-    #     viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
+        # Pick a random Camera
+        if not viewpoint_stack:
+            viewpoint_stack = scene.getTrainCameras().copy()
+        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
 
-    #     # Render
-    #     if (iteration - 1) == debug_from:
-    #         pipe.debug = True
-    #     bg = torch.rand((3), device="cuda") if opt.random_background else background
+        # Render
+        if (iteration - 1) == debug_from:
+            pipe.debug = True
+        bg = torch.rand((3), device="cuda") if opt.random_background else background
 
-    #     render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
-    #     image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
-
-    #     # Loss
-    #     gt_image = viewpoint_cam.original_image.cuda()
-    #     Ll1 = l1_loss(image, gt_image)
-    #     loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
-    #     loss.backward()
-    #     iter_end.record()
+        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
+
+        # Loss
+        gt_image = viewpoint_cam.original_image.cuda()
+        Ll1 = l1_loss(image, gt_image)
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+        loss.backward()
+        iter_end.record()
 
-    #     with torch.no_grad():
-    #         # Progress bar
-    #         ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
-    #         if iteration % 10 == 0:
-    #             progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
-    #             progress_bar.update(10)
-    #             progress(iteration / opt.iterations) # Update Gradio progress bar
-    #         if iteration == opt.iterations:
-    #             progress_bar.close()
-
-    #         # Log and save
-    #         training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
-    #         if (iteration == opt.iterations):
-    #             point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
-    #             print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
-    #             scene.save(iteration)
-
-    #         # Densification
-    #         if iteration < opt.densify_until_iter:
-    #             # Keep track of max radii in image-space for pruning
-    #             gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
-    #             gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
-
-    #             if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
-    #                 size_threshold = 20 if iteration > opt.opacity_reset_interval else None
-    #                 gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
-
-    #             if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
-    #                 gaussians.reset_opacity()
-
-    #         # Optimizer step
-    #         if iteration < opt.iterations:
-    #             gaussians.optimizer.step()
-    #             gaussians.optimizer.zero_grad(set_to_none = True)
-
-    #         if (iteration == opt.iterations):
-    #             print("\n[ITER {}] Saving Checkpoint".format(iteration))
-    #             torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
-
-
-    # from os import makedirs
-    # from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
-    # import torchvision
-    # import subprocess
-
-    # @torch.no_grad()
-    # def render_path(dataset : ModelParams, iteration : int, pipeline : PipelineParams, render_resize_method='crop'):
-    #     """
-    #     render_resize_method: crop, pad
-    #     """
-    #     gaussians = GaussianModel(dataset.sh_degree)
-    #     scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
-
-    #     iteration = scene.loaded_iter
-
-    #     bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
-    #     background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
-
-    #     model_path = dataset.model_path
-    #     name = "render"
-
-    #     views = scene.getRenderCameras()
-
-    #     # print(len(views))
-    #     render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
-
-    #     makedirs(render_path, exist_ok=True)
-
-    #     for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
-    #         if render_resize_method == 'crop':
-    #             image_size = 256
-    #         elif render_resize_method == 'pad':
-    #             image_size = max(view.image_width, view.image_height)
-    #         else:
-    #             raise NotImplementedError
-    #         view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
-    #         focal_length_x = fov2focal(view.FoVx, view.image_width)
-    #         focal_length_y = fov2focal(view.FoVy, view.image_height)
-    #         view.image_width = image_size
-    #         view.image_height = image_size
-    #         view.FoVx = focal2fov(focal_length_x, image_size)
-    #         view.FoVy = focal2fov(focal_length_y, image_size)
-    #         view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0,1).cuda().float()
-    #         view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
-
-    #         render_pkg = render(view, gaussians, pipeline, background)
-    #         rendering = render_pkg["render"]
-    #         torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
-
-    #     # Use ffmpeg to output video
-    #     renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
-    #     # Use ffmpeg to output video
-    #     subprocess.run(["ffmpeg", "-y",
-    #                 "-framerate", "24",
-    #                 "-i", os.path.join(render_path, "%05d.png"),
-    #                 "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
-    #                 "-c:v", "libx264",
-    #                 "-pix_fmt", "yuv420p",
-    #                 "-crf", "23",
-    #                 # "-pix_fmt", "yuv420p", # Set pixel format for compatibility
-    #                 renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-    #     )
-    #     return renders_path
+        with torch.no_grad():
+            # Progress bar
+            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+            if iteration % 10 == 0:
+                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+                progress_bar.update(10)
+                progress(iteration / opt.iterations) # Update Gradio progress bar
+            if iteration == opt.iterations:
+                progress_bar.close()
+
+            # Log and save
+            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
+            if (iteration == opt.iterations):
+                point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
+                print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
+                scene.save(iteration)
+
+            # Densification
+            if iteration < opt.densify_until_iter:
+                # Keep track of max radii in image-space for pruning
+                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
+                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
+                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
+
+                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
+                    gaussians.reset_opacity()
+
+            # Optimizer step
+            if iteration < opt.iterations:
+                gaussians.optimizer.step()
+                gaussians.optimizer.zero_grad(set_to_none = True)
+
+            if (iteration == opt.iterations):
+                print("\n[ITER {}] Saving Checkpoint".format(iteration))
+                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
+
+
+    from os import makedirs
+    from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
+    import torchvision
+    import subprocess
+
+    @torch.no_grad()
+    def render_path(dataset : ModelParams, iteration : int, pipeline : PipelineParams, render_resize_method='crop'):
+        """
+        render_resize_method: crop, pad
+        """
+        gaussians = GaussianModel(dataset.sh_degree)
+        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
+
+        iteration = scene.loaded_iter
+
+        bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
+        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+        model_path = dataset.model_path
+        name = "render"
+
+        views = scene.getRenderCameras()
+
+        # print(len(views))
+        render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+
+        makedirs(render_path, exist_ok=True)
+
+        for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+            if render_resize_method == 'crop':
+                image_size = 256
+            elif render_resize_method == 'pad':
+                image_size = max(view.image_width, view.image_height)
+            else:
+                raise NotImplementedError
+            view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
+            focal_length_x = fov2focal(view.FoVx, view.image_width)
+            focal_length_y = fov2focal(view.FoVy, view.image_height)
+            view.image_width = image_size
+            view.image_height = image_size
+            view.FoVx = focal2fov(focal_length_x, image_size)
+            view.FoVy = focal2fov(focal_length_y, image_size)
+            view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0,1).cuda().float()
+            view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
+
+            render_pkg = render(view, gaussians, pipeline, background)
+            rendering = render_pkg["render"]
+            torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
+
+        # Use ffmpeg to output video
+        renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
+        subprocess.run(["ffmpeg", "-y",
+                        "-framerate", "24",
+                        "-i", os.path.join(render_path, "%05d.png"),
+                        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
+                        "-c:v", "libx264",
+                        "-pix_fmt", "yuv420p",
+                        "-crf", "23",
+                        renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+        )
+        return renders_path
 
-    # renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
+    renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
 
-    # return renders_path, point_cloud_path
-    return None, None
+    return renders_path, point_cloud_path
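
The re-enabled block at the top of demo/gs_train.py loads gaussian-splatting's train.py by file path rather than by package name, using the standard importlib recipe. A minimal sketch of that recipe, where helpers.py and its greet function are made-up stand-ins for train.py and its helpers:

import importlib.util
import os

# Build a module spec from an explicit file path, create the module object,
# then execute it so its top-level definitions exist.
spec = importlib.util.spec_from_file_location("helpers", os.path.join(".", "helpers.py"))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

greet = module.greet  # pick out the names you need, as gs_train.py does with prepare_output_and_logger

This avoids turning the gaussian-splatting checkout into an installable package; the sys.path.append(gaussian_splatting_path) earlier in the file is presumably still needed so that train.py's own relative imports resolve.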
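
Two small formulas in the restored training loop are worth spelling out. The photometric loss blends L1 with a structural term, loss = (1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM), where lambda_dssim is 0.2 by default in the reference gaussian-splatting code, and the value shown on the progress bar is an exponential moving average rather than the raw loss. A self-contained sketch with stand-in tensors (torch only; the real ssim comes from utils.loss_utils):

import torch
import torch.nn.functional as F

lambda_dssim = 0.2                  # weight of the SSIM term (reference default)
image = torch.rand(3, 64, 64)       # stand-in rendered image
gt_image = torch.rand(3, 64, 64)    # stand-in ground truth

Ll1 = F.l1_loss(image, gt_image)    # mean absolute error
ssim_value = torch.tensor(0.5)      # placeholder for ssim(image, gt_image)
loss = (1.0 - lambda_dssim) * Ll1 + lambda_dssim * (1.0 - ssim_value)

# Progress-bar smoothing: new value weighted 0.4, running history weighted 0.6.
ema_loss_for_log = 0.0
for _ in range(5):
    ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log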
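
render_path changes the render resolution by holding the focal length fixed and recomputing the field of view. The helpers come from gaussian-splatting's utils.graphics_utils and implement the standard pinhole relations; a re-derivation for reference, matching the signatures used above:

import math

def fov2focal(fov, pixels):
    # Focal length in pixels from a field of view in radians: f = w / (2 tan(fov/2)).
    return pixels / (2.0 * math.tan(fov / 2.0))

def focal2fov(focal, pixels):
    # Inverse relation: fov = 2 atan(w / (2 f)).
    return 2.0 * math.atan(pixels / (2.0 * focal))

# Cropping a 640 px wide, 60-degree view down to 256 px while keeping the same
# focal length narrows the field of view, which is exactly the 'crop' behaviour.
f = fov2focal(math.radians(60.0), 640)
print(math.degrees(focal2fov(f, 256)))  # roughly 26 degrees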
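
Finally, the rendered frames are stitched into renders.mp4 with ffmpeg. The flags in the diff are standard: -framerate 24 sets the input frame rate, the pad filter rounds both dimensions up to even numbers (libx264 with yuv420p requires even width and height), -crf 23 is x264's default quality, and -y overwrites any existing output. A trimmed, standalone version of the same call (paths are illustrative):

import subprocess

frames = "renders/%05d.png"   # illustrative frame pattern
out = "renders.mp4"           # illustrative output path

subprocess.run(["ffmpeg", "-y",
                "-framerate", "24",
                "-i", frames,
                "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
                "-c:v", "libx264",
                "-pix_fmt", "yuv420p",
                "-crf", "23",
                out], check=True)

Passing check=True raises on failure; the diff instead silences ffmpeg's output with DEVNULL, which keeps the Gradio log clean but makes a missing video harder to debug.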