ostapagon committed on
Commit 05ec1e2 · 1 Parent(s): a55d4eb

Check if diff ext is working

Files changed (1)
  1. demo/gs_train.py +199 -198
demo/gs_train.py CHANGED
@@ -14,18 +14,18 @@ sys.path.append(gaussian_splatting_path)
  # Import necessary modules from the gaussian-splatting directory
  from utils.loss_utils import l1_loss, ssim
  from gaussian_renderer import render, network_gui
- from scene import Scene, GaussianModel
+ # from scene import Scene, GaussianModel
  from utils.general_utils import safe_state
  from utils.image_utils import psnr

- # Dynamically import the train module from the gaussian-splatting directory
- train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
- gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
- train_spec.loader.exec_module(gaussian_splatting_train)
+ # # Dynamically import the train module from the gaussian-splatting directory
+ # train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
+ # gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
+ # train_spec.loader.exec_module(gaussian_splatting_train)

- # Import the necessary functions from the dynamically loaded module
- prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
- training_report = gaussian_splatting_train.training_report
+ # # Import the necessary functions from the dynamically loaded module
+ # prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
+ # training_report = gaussian_splatting_train.training_report

  from dataclasses import dataclass, field

@@ -87,203 +87,204 @@ def train(
  densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
  ):
  print(data_source_path)
- # Create instances of the parameter dataclasses
- dataset = ModelParams(
- sh_degree=sh_degree,
- source_path=data_source_path,
- model_path=model_path,
- images=images,
- resolution=resolution,
- white_background=white_background,
- data_device=data_device,
- eval=eval
- )
+ # # Create instances of the parameter dataclasses
+ # dataset = ModelParams(
+ # sh_degree=sh_degree,
+ # source_path=data_source_path,
+ # model_path=model_path,
+ # images=images,
+ # resolution=resolution,
+ # white_background=white_background,
+ # data_device=data_device,
+ # eval=eval
+ # )

- pipe = PipelineParams(
- convert_SHs_python=convert_SHs_python,
- compute_cov3D_python=compute_cov3D_python,
- debug=debug
- )
+ # pipe = PipelineParams(
+ # convert_SHs_python=convert_SHs_python,
+ # compute_cov3D_python=compute_cov3D_python,
+ # debug=debug
+ # )

- opt = OptimizationParams(
- iterations=iterations,
- position_lr_init=position_lr_init,
- position_lr_final=position_lr_final,
- position_lr_delay_mult=position_lr_delay_mult,
- position_lr_max_steps=position_lr_max_steps,
- feature_lr=feature_lr,
- opacity_lr=opacity_lr,
- scaling_lr=scaling_lr,
- rotation_lr=rotation_lr,
- percent_dense=percent_dense,
- lambda_dssim=lambda_dssim,
- densification_interval=densification_interval,
- opacity_reset_interval=opacity_reset_interval,
- densify_from_iter=densify_from_iter,
- densify_until_iter=densify_until_iter,
- densify_grad_threshold=densify_grad_threshold,
- random_background=random_background
- )
-
- args = TrainingArgs()
-
- testing_iterations = args.test_iterations
- saving_iterations = args.save_iterations
- checkpoint_iterations = args.checkpoint_iterations
- debug_from = args.debug_from
-
- tb_writer = prepare_output_and_logger(dataset)
+ # opt = OptimizationParams(
+ # iterations=iterations,
+ # position_lr_init=position_lr_init,
+ # position_lr_final=position_lr_final,
+ # position_lr_delay_mult=position_lr_delay_mult,
+ # position_lr_max_steps=position_lr_max_steps,
+ # feature_lr=feature_lr,
+ # opacity_lr=opacity_lr,
+ # scaling_lr=scaling_lr,
+ # rotation_lr=rotation_lr,
+ # percent_dense=percent_dense,
+ # lambda_dssim=lambda_dssim,
+ # densification_interval=densification_interval,
+ # opacity_reset_interval=opacity_reset_interval,
+ # densify_from_iter=densify_from_iter,
+ # densify_until_iter=densify_until_iter,
+ # densify_grad_threshold=densify_grad_threshold,
+ # random_background=random_background
+ # )
+
+ # args = TrainingArgs()
+
+ # testing_iterations = args.test_iterations
+ # saving_iterations = args.save_iterations
+ # checkpoint_iterations = args.checkpoint_iterations
+ # debug_from = args.debug_from
+
+ # tb_writer = prepare_output_and_logger(dataset)

- gaussians = GaussianModel(dataset.sh_degree)
- scene = Scene(dataset, gaussians)
- gaussians.training_setup(opt)
-
- bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
- background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
-
- iter_start = torch.cuda.Event(enable_timing = True)
- iter_end = torch.cuda.Event(enable_timing = True)
-
- viewpoint_stack = None
- ema_loss_for_log = 0.0
- first_iter = 0
- progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
- first_iter += 1
-
- point_cloud_path = ""
- progress = gr.Progress() # Initialize the progress bar
- for iteration in range(first_iter, opt.iterations + 1):
- iter_start.record()
- gaussians.update_learning_rate(iteration)
+ # gaussians = GaussianModel(dataset.sh_degree)
+ # scene = Scene(dataset, gaussians)
+ # gaussians.training_setup(opt)
+
+ # bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ # background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ # iter_start = torch.cuda.Event(enable_timing = True)
+ # iter_end = torch.cuda.Event(enable_timing = True)
+
+ # viewpoint_stack = None
+ # ema_loss_for_log = 0.0
+ # first_iter = 0
+ # progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+ # first_iter += 1
+
+ # point_cloud_path = ""
+ # progress = gr.Progress() # Initialize the progress bar
+ # for iteration in range(first_iter, opt.iterations + 1):
+ # iter_start.record()
+ # gaussians.update_learning_rate(iteration)

- # Every 1000 its we increase the levels of SH up to a maximum degree
- if iteration % 1000 == 0:
- gaussians.oneupSHdegree()
+ # # Every 1000 its we increase the levels of SH up to a maximum degree
+ # if iteration % 1000 == 0:
+ # gaussians.oneupSHdegree()

- # Pick a random Camera
- if not viewpoint_stack:
- viewpoint_stack = scene.getTrainCameras().copy()
- viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
+ # # Pick a random Camera
+ # if not viewpoint_stack:
+ # viewpoint_stack = scene.getTrainCameras().copy()
+ # viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))

- # Render
- if (iteration - 1) == debug_from:
- pipe.debug = True
- bg = torch.rand((3), device="cuda") if opt.random_background else background
+ # # Render
+ # if (iteration - 1) == debug_from:
+ # pipe.debug = True
+ # bg = torch.rand((3), device="cuda") if opt.random_background else background

- render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
- image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
-
- # Loss
- gt_image = viewpoint_cam.original_image.cuda()
- Ll1 = l1_loss(image, gt_image)
- loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
- loss.backward()
- iter_end.record()
+ # render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+ # image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
+
+ # # Loss
+ # gt_image = viewpoint_cam.original_image.cuda()
+ # Ll1 = l1_loss(image, gt_image)
+ # loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+ # loss.backward()
+ # iter_end.record()

- with torch.no_grad():
- # Progress bar
- ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
- if iteration % 10 == 0:
- progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
- progress_bar.update(10)
- progress(iteration / opt.iterations) # Update Gradio progress bar
- if iteration == opt.iterations:
- progress_bar.close()
-
- # Log and save
- training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
- if (iteration == opt.iterations):
- point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
- print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
- scene.save(iteration)
-
- # Densification
- if iteration < opt.densify_until_iter:
- # Keep track of max radii in image-space for pruning
- gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
- gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
-
- if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
- size_threshold = 20 if iteration > opt.opacity_reset_interval else None
- gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
-
- if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
- gaussians.reset_opacity()
-
- # Optimizer step
- if iteration < opt.iterations:
- gaussians.optimizer.step()
- gaussians.optimizer.zero_grad(set_to_none = True)
-
- if (iteration == opt.iterations):
- print("\n[ITER {}] Saving Checkpoint".format(iteration))
- torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
-
-
- from os import makedirs
- from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
- import torchvision
- import subprocess
-
- @torch.no_grad()
- def render_path(dataset : ModelParams, iteration : int, pipeline : PipelineParams, render_resize_method='crop'):
- """
- render_resize_method: crop, pad
- """
- gaussians = GaussianModel(dataset.sh_degree)
- scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
-
- iteration = scene.loaded_iter
-
- bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
- background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
-
- model_path = dataset.model_path
- name = "render"
-
- views = scene.getRenderCameras()
-
- # print(len(views))
- render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
-
- makedirs(render_path, exist_ok=True)
-
- for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
- if render_resize_method == 'crop':
- image_size = 256
- elif render_resize_method == 'pad':
- image_size = max(view.image_width, view.image_height)
- else:
- raise NotImplementedError
- view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
- focal_length_x = fov2focal(view.FoVx, view.image_width)
- focal_length_y = fov2focal(view.FoVy, view.image_height)
- view.image_width = image_size
- view.image_height = image_size
- view.FoVx = focal2fov(focal_length_x, image_size)
- view.FoVy = focal2fov(focal_length_y, image_size)
- view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0,1).cuda().float()
- view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
-
- render_pkg = render(view, gaussians, pipeline, background)
- rendering = render_pkg["render"]
- torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
-
- # Use ffmpeg to output video
- renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
- # Use ffmpeg to output video
- subprocess.run(["ffmpeg", "-y",
- "-framerate", "24",
- "-i", os.path.join(render_path, "%05d.png"),
- "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
- "-c:v", "libx264",
- "-pix_fmt", "yuv420p",
- "-crf", "23",
- # "-pix_fmt", "yuv420p", # Set pixel format for compatibility
- renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
- )
- return renders_path
+ # with torch.no_grad():
+ # # Progress bar
+ # ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+ # if iteration % 10 == 0:
+ # progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+ # progress_bar.update(10)
+ # progress(iteration / opt.iterations) # Update Gradio progress bar
+ # if iteration == opt.iterations:
+ # progress_bar.close()
+
+ # # Log and save
+ # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
+ # if (iteration == opt.iterations):
+ # point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
+ # print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
+ # scene.save(iteration)
+
+ # # Densification
+ # if iteration < opt.densify_until_iter:
+ # # Keep track of max radii in image-space for pruning
+ # gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+ # gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+ # if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
+ # size_threshold = 20 if iteration > opt.opacity_reset_interval else None
+ # gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
+
+ # if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
+ # gaussians.reset_opacity()
+
+ # # Optimizer step
+ # if iteration < opt.iterations:
+ # gaussians.optimizer.step()
+ # gaussians.optimizer.zero_grad(set_to_none = True)
+
+ # if (iteration == opt.iterations):
+ # print("\n[ITER {}] Saving Checkpoint".format(iteration))
+ # torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
+
+
+ # from os import makedirs
+ # from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
+ # import torchvision
+ # import subprocess
+
+ # @torch.no_grad()
+ # def render_path(dataset : ModelParams, iteration : int, pipeline : PipelineParams, render_resize_method='crop'):
+ # """
+ # render_resize_method: crop, pad
+ # """
+ # gaussians = GaussianModel(dataset.sh_degree)
+ # scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
+
+ # iteration = scene.loaded_iter
+
+ # bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
+ # background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ # model_path = dataset.model_path
+ # name = "render"
+
+ # views = scene.getRenderCameras()
+
+ # # print(len(views))
+ # render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+
+ # makedirs(render_path, exist_ok=True)
+
+ # for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+ # if render_resize_method == 'crop':
+ # image_size = 256
+ # elif render_resize_method == 'pad':
+ # image_size = max(view.image_width, view.image_height)
+ # else:
+ # raise NotImplementedError
+ # view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
+ # focal_length_x = fov2focal(view.FoVx, view.image_width)
+ # focal_length_y = fov2focal(view.FoVy, view.image_height)
+ # view.image_width = image_size
+ # view.image_height = image_size
+ # view.FoVx = focal2fov(focal_length_x, image_size)
+ # view.FoVy = focal2fov(focal_length_y, image_size)
+ # view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0,1).cuda().float()
+ # view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
+
+ # render_pkg = render(view, gaussians, pipeline, background)
+ # rendering = render_pkg["render"]
+ # torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
+
+ # # Use ffmpeg to output video
+ # renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
+ # # Use ffmpeg to output video
+ # subprocess.run(["ffmpeg", "-y",
+ # "-framerate", "24",
+ # "-i", os.path.join(render_path, "%05d.png"),
+ # "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
+ # "-c:v", "libx264",
+ # "-pix_fmt", "yuv420p",
+ # "-crf", "23",
+ # # "-pix_fmt", "yuv420p", # Set pixel format for compatibility
+ # renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+ # )
+ # return renders_path

- renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
+ # renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')

- return renders_path, point_cloud_path
+ # return renders_path, point_cloud_path
+ return None, None