Spaces:
Running
on
Zero
Running
on
Zero
Check if diff ext is working
Browse files- demo/gs_train.py +199 -198
demo/gs_train.py
CHANGED
@@ -14,18 +14,18 @@ sys.path.append(gaussian_splatting_path)
|
|
14 |
# Import necessary modules from the gaussian-splatting directory
|
15 |
from utils.loss_utils import l1_loss, ssim
|
16 |
from gaussian_renderer import render, network_gui
|
17 |
-
from scene import Scene, GaussianModel
|
18 |
from utils.general_utils import safe_state
|
19 |
from utils.image_utils import psnr
|
20 |
|
21 |
-
# Dynamically import the train module from the gaussian-splatting directory
|
22 |
-
train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
|
23 |
-
gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
|
24 |
-
train_spec.loader.exec_module(gaussian_splatting_train)
|
25 |
|
26 |
-
# Import the necessary functions from the dynamically loaded module
|
27 |
-
prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
|
28 |
-
training_report = gaussian_splatting_train.training_report
|
29 |
|
30 |
from dataclasses import dataclass, field
|
31 |
|
@@ -87,203 +87,204 @@ def train(
|
|
87 |
densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
|
88 |
):
|
89 |
print(data_source_path)
|
90 |
-
# Create instances of the parameter dataclasses
|
91 |
-
dataset = ModelParams(
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
)
|
101 |
|
102 |
-
pipe = PipelineParams(
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
)
|
107 |
|
108 |
-
opt = OptimizationParams(
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
)
|
127 |
-
|
128 |
-
args = TrainingArgs()
|
129 |
-
|
130 |
-
testing_iterations = args.test_iterations
|
131 |
-
saving_iterations = args.save_iterations
|
132 |
-
checkpoint_iterations = args.checkpoint_iterations
|
133 |
-
debug_from = args.debug_from
|
134 |
-
|
135 |
-
tb_writer = prepare_output_and_logger(dataset)
|
136 |
|
137 |
-
gaussians = GaussianModel(dataset.sh_degree)
|
138 |
-
scene = Scene(dataset, gaussians)
|
139 |
-
gaussians.training_setup(opt)
|
140 |
-
|
141 |
-
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
142 |
-
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
143 |
-
|
144 |
-
iter_start = torch.cuda.Event(enable_timing = True)
|
145 |
-
iter_end = torch.cuda.Event(enable_timing = True)
|
146 |
-
|
147 |
-
viewpoint_stack = None
|
148 |
-
ema_loss_for_log = 0.0
|
149 |
-
first_iter = 0
|
150 |
-
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
151 |
-
first_iter += 1
|
152 |
-
|
153 |
-
point_cloud_path = ""
|
154 |
-
progress = gr.Progress() # Initialize the progress bar
|
155 |
-
for iteration in range(first_iter, opt.iterations + 1):
|
156 |
-
|
157 |
-
|
158 |
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
from os import makedirs
|
224 |
-
from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
|
225 |
-
import torchvision
|
226 |
-
import subprocess
|
227 |
-
|
228 |
-
@torch.no_grad()
|
229 |
-
def render_path(dataset : ModelParams, iteration : int, pipeline : PipelineParams, render_resize_method='crop'):
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
|
287 |
-
renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
|
288 |
|
289 |
-
return renders_path, point_cloud_path
|
|
|
|
14 |
# Import necessary modules from the gaussian-splatting directory
|
15 |
from utils.loss_utils import l1_loss, ssim
|
16 |
from gaussian_renderer import render, network_gui
|
17 |
+
# from scene import Scene, GaussianModel
|
18 |
from utils.general_utils import safe_state
|
19 |
from utils.image_utils import psnr
|
20 |
|
21 |
+
# # Dynamically import the train module from the gaussian-splatting directory
|
22 |
+
# train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
|
23 |
+
# gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
|
24 |
+
# train_spec.loader.exec_module(gaussian_splatting_train)
|
25 |
|
26 |
+
# # Import the necessary functions from the dynamically loaded module
|
27 |
+
# prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
|
28 |
+
# training_report = gaussian_splatting_train.training_report
|
29 |
|
30 |
from dataclasses import dataclass, field
|
31 |
|
|
|
87 |
densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
|
88 |
):
|
89 |
print(data_source_path)
|
90 |
+
# # Create instances of the parameter dataclasses
|
91 |
+
# dataset = ModelParams(
|
92 |
+
# sh_degree=sh_degree,
|
93 |
+
# source_path=data_source_path,
|
94 |
+
# model_path=model_path,
|
95 |
+
# images=images,
|
96 |
+
# resolution=resolution,
|
97 |
+
# white_background=white_background,
|
98 |
+
# data_device=data_device,
|
99 |
+
# eval=eval
|
100 |
+
# )
|
101 |
|
102 |
+
# pipe = PipelineParams(
|
103 |
+
# convert_SHs_python=convert_SHs_python,
|
104 |
+
# compute_cov3D_python=compute_cov3D_python,
|
105 |
+
# debug=debug
|
106 |
+
# )
|
107 |
|
108 |
+
# opt = OptimizationParams(
|
109 |
+
# iterations=iterations,
|
110 |
+
# position_lr_init=position_lr_init,
|
111 |
+
# position_lr_final=position_lr_final,
|
112 |
+
# position_lr_delay_mult=position_lr_delay_mult,
|
113 |
+
# position_lr_max_steps=position_lr_max_steps,
|
114 |
+
# feature_lr=feature_lr,
|
115 |
+
# opacity_lr=opacity_lr,
|
116 |
+
# scaling_lr=scaling_lr,
|
117 |
+
# rotation_lr=rotation_lr,
|
118 |
+
# percent_dense=percent_dense,
|
119 |
+
# lambda_dssim=lambda_dssim,
|
120 |
+
# densification_interval=densification_interval,
|
121 |
+
# opacity_reset_interval=opacity_reset_interval,
|
122 |
+
# densify_from_iter=densify_from_iter,
|
123 |
+
# densify_until_iter=densify_until_iter,
|
124 |
+
# densify_grad_threshold=densify_grad_threshold,
|
125 |
+
# random_background=random_background
|
126 |
+
# )
|
127 |
+
|
128 |
+
# args = TrainingArgs()
|
129 |
+
|
130 |
+
# testing_iterations = args.test_iterations
|
131 |
+
# saving_iterations = args.save_iterations
|
132 |
+
# checkpoint_iterations = args.checkpoint_iterations
|
133 |
+
# debug_from = args.debug_from
|
134 |
+
|
135 |
+
# tb_writer = prepare_output_and_logger(dataset)
|
136 |
|
137 |
+
# gaussians = GaussianModel(dataset.sh_degree)
|
138 |
+
# scene = Scene(dataset, gaussians)
|
139 |
+
# gaussians.training_setup(opt)
|
140 |
+
|
141 |
+
# bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
142 |
+
# background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
143 |
+
|
144 |
+
# iter_start = torch.cuda.Event(enable_timing = True)
|
145 |
+
# iter_end = torch.cuda.Event(enable_timing = True)
|
146 |
+
|
147 |
+
# viewpoint_stack = None
|
148 |
+
# ema_loss_for_log = 0.0
|
149 |
+
# first_iter = 0
|
150 |
+
# progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
151 |
+
# first_iter += 1
|
152 |
+
|
153 |
+
# point_cloud_path = ""
|
154 |
+
# progress = gr.Progress() # Initialize the progress bar
|
155 |
+
# for iteration in range(first_iter, opt.iterations + 1):
|
156 |
+
# iter_start.record()
|
157 |
+
# gaussians.update_learning_rate(iteration)
|
158 |
|
159 |
+
# # Every 1000 its we increase the levels of SH up to a maximum degree
|
160 |
+
# if iteration % 1000 == 0:
|
161 |
+
# gaussians.oneupSHdegree()
|
162 |
|
163 |
+
# # Pick a random Camera
|
164 |
+
# if not viewpoint_stack:
|
165 |
+
# viewpoint_stack = scene.getTrainCameras().copy()
|
166 |
+
# viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
|
167 |
|
168 |
+
# # Render
|
169 |
+
# if (iteration - 1) == debug_from:
|
170 |
+
# pipe.debug = True
|
171 |
+
# bg = torch.rand((3), device="cuda") if opt.random_background else background
|
172 |
|
173 |
+
# render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
|
174 |
+
# image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
175 |
+
|
176 |
+
# # Loss
|
177 |
+
# gt_image = viewpoint_cam.original_image.cuda()
|
178 |
+
# Ll1 = l1_loss(image, gt_image)
|
179 |
+
# loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
|
180 |
+
# loss.backward()
|
181 |
+
# iter_end.record()
|
182 |
|
183 |
+
# with torch.no_grad():
|
184 |
+
# # Progress bar
|
185 |
+
# ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
|
186 |
+
# if iteration % 10 == 0:
|
187 |
+
# progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
|
188 |
+
# progress_bar.update(10)
|
189 |
+
# progress(iteration / opt.iterations) # Update Gradio progress bar
|
190 |
+
# if iteration == opt.iterations:
|
191 |
+
# progress_bar.close()
|
192 |
+
|
193 |
+
# # Log and save
|
194 |
+
# training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
|
195 |
+
# if (iteration == opt.iterations):
|
196 |
+
# point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
|
197 |
+
# print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
|
198 |
+
# scene.save(iteration)
|
199 |
+
|
200 |
+
# # Densification
|
201 |
+
# if iteration < opt.densify_until_iter:
|
202 |
+
# # Keep track of max radii in image-space for pruning
|
203 |
+
# gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
204 |
+
# gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
|
205 |
+
|
206 |
+
# if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
|
207 |
+
# size_threshold = 20 if iteration > opt.opacity_reset_interval else None
|
208 |
+
# gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
|
209 |
+
|
210 |
+
# if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
|
211 |
+
# gaussians.reset_opacity()
|
212 |
+
|
213 |
+
# # Optimizer step
|
214 |
+
# if iteration < opt.iterations:
|
215 |
+
# gaussians.optimizer.step()
|
216 |
+
# gaussians.optimizer.zero_grad(set_to_none = True)
|
217 |
+
|
218 |
+
# if (iteration == opt.iterations):
|
219 |
+
# print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
220 |
+
# torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
221 |
+
|
222 |
+
|
223 |
+
# from os import makedirs
|
224 |
+
# from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
|
225 |
+
# import torchvision
|
226 |
+
# import subprocess
|
227 |
+
|
228 |
+
# @torch.no_grad()
|
229 |
+
# def render_path(dataset : ModelParams, iteration : int, pipeline : PipelineParams, render_resize_method='crop'):
|
230 |
+
# """
|
231 |
+
# render_resize_method: crop, pad
|
232 |
+
# """
|
233 |
+
# gaussians = GaussianModel(dataset.sh_degree)
|
234 |
+
# scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
|
235 |
+
|
236 |
+
# iteration = scene.loaded_iter
|
237 |
+
|
238 |
+
# bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
|
239 |
+
# background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
240 |
+
|
241 |
+
# model_path = dataset.model_path
|
242 |
+
# name = "render"
|
243 |
+
|
244 |
+
# views = scene.getRenderCameras()
|
245 |
+
|
246 |
+
# # print(len(views))
|
247 |
+
# render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
|
248 |
+
|
249 |
+
# makedirs(render_path, exist_ok=True)
|
250 |
+
|
251 |
+
# for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
|
252 |
+
# if render_resize_method == 'crop':
|
253 |
+
# image_size = 256
|
254 |
+
# elif render_resize_method == 'pad':
|
255 |
+
# image_size = max(view.image_width, view.image_height)
|
256 |
+
# else:
|
257 |
+
# raise NotImplementedError
|
258 |
+
# view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
|
259 |
+
# focal_length_x = fov2focal(view.FoVx, view.image_width)
|
260 |
+
# focal_length_y = fov2focal(view.FoVy, view.image_height)
|
261 |
+
# view.image_width = image_size
|
262 |
+
# view.image_height = image_size
|
263 |
+
# view.FoVx = focal2fov(focal_length_x, image_size)
|
264 |
+
# view.FoVy = focal2fov(focal_length_y, image_size)
|
265 |
+
# view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0,1).cuda().float()
|
266 |
+
# view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
|
267 |
+
|
268 |
+
# render_pkg = render(view, gaussians, pipeline, background)
|
269 |
+
# rendering = render_pkg["render"]
|
270 |
+
# torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
|
271 |
+
|
272 |
+
# # Use ffmpeg to output video
|
273 |
+
# renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
|
274 |
+
# # Use ffmpeg to output video
|
275 |
+
# subprocess.run(["ffmpeg", "-y",
|
276 |
+
# "-framerate", "24",
|
277 |
+
# "-i", os.path.join(render_path, "%05d.png"),
|
278 |
+
# "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
|
279 |
+
# "-c:v", "libx264",
|
280 |
+
# "-pix_fmt", "yuv420p",
|
281 |
+
# "-crf", "23",
|
282 |
+
# # "-pix_fmt", "yuv420p", # Set pixel format for compatibility
|
283 |
+
# renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
|
284 |
+
# )
|
285 |
+
# return renders_path
|
286 |
|
287 |
+
# renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
|
288 |
|
289 |
+
# return renders_path, point_cloud_path
|
290 |
+
return None, None
|