ostapagon commited on
Commit
842f5dd
·
1 Parent(s): 2956ec7

Add lazy loading for cuda libs

Browse files
Files changed (2) hide show
  1. demo/gs_train.py +26 -21
  2. demo/mast3r_demo.py +35 -34
demo/gs_train.py CHANGED
@@ -6,28 +6,10 @@ import uuid
6
  from tqdm.auto import tqdm
7
  import gradio as gr
8
  import importlib.util
 
9
 
10
- # Add the path to the gaussian-splatting repository
11
- gaussian_splatting_path = 'wild-gaussian-splatting/gaussian-splatting/'
12
- sys.path.append(gaussian_splatting_path)
13
-
14
- # Import necessary modules from the gaussian-splatting directory
15
- from utils.loss_utils import l1_loss, ssim
16
- from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
17
- from scene import Scene, GaussianModel
18
- from utils.general_utils import safe_state
19
- from utils.image_utils import psnr
20
-
21
- # Dynamically import the train module from the gaussian-splatting directory
22
- train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
23
- gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
24
- train_spec.loader.exec_module(gaussian_splatting_train)
25
-
26
- # Import the necessary functions from the dynamically loaded module
27
- prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
28
- training_report = gaussian_splatting_train.training_report
29
 
30
- from dataclasses import dataclass, field
31
 
32
  @dataclass
33
  class PipelineParams:
@@ -78,6 +60,7 @@ class TrainingArgs:
78
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
79
  start_checkpoint: str = None
80
 
 
81
  def train(
82
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
83
  convert_SHs_python, compute_cov3D_python, debug,
@@ -86,6 +69,29 @@ def train(
86
  percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
87
  densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
88
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  print(data_source_path)
90
  # Create instances of the parameter dataclasses
91
  dataset = ModelParams(
@@ -221,7 +227,6 @@ def train(
221
 
222
 
223
  from os import makedirs
224
- from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
225
  import torchvision
226
  import subprocess
227
 
 
6
  from tqdm.auto import tqdm
7
  import gradio as gr
8
  import importlib.util
9
+ from dataclasses import dataclass, field
10
 
11
+ import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
13
 
14
  @dataclass
15
  class PipelineParams:
 
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
+ @spaces.GPU(duration=600)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
 
69
  percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
70
  densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
71
  ):
72
+
73
+ # Add the path to the gaussian-splatting repository
74
+ if 'GaussianRasterizer' not in globals():
75
+ gaussian_splatting_path = 'wild-gaussian-splatting/gaussian-splatting/'
76
+ sys.path.append(gaussian_splatting_path)
77
+
78
+ # Import necessary modules from the gaussian-splatting directory
79
+ from utils.loss_utils import l1_loss, ssim
80
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
81
+ from scene import Scene, GaussianModel
82
+ from utils.general_utils import safe_state
83
+ from utils.image_utils import psnr
84
+ from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
85
+
86
+ # Dynamically import the train module from the gaussian-splatting directory
87
+ train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
88
+ gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
89
+ train_spec.loader.exec_module(gaussian_splatting_train)
90
+
91
+ # Import the necessary functions from the dynamically loaded module
92
+ prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
93
+ training_report = gaussian_splatting_train.training_report
94
+
95
  print(data_source_path)
96
  # Create instances of the parameter dataclasses
97
  dataset = ModelParams(
 
227
 
228
 
229
  from os import makedirs
 
230
  import torchvision
231
  import subprocess
232
 
demo/mast3r_demo.py CHANGED
@@ -29,18 +29,6 @@ from dust3r.utils.device import to_numpy
29
  from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
30
  from dust3r.demo import get_args_parser as dust3r_get_args_parser
31
 
32
- sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
33
- sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/src'))
34
- # from colmap_dataset_utils import (
35
- # inv,
36
- # init_filestructure,
37
- # save_images_masks,
38
- # save_cameras,
39
- # save_imagestxt,
40
- # save_pointcloud,
41
- # save_pointcloud_with_normals
42
- # )
43
-
44
  import matplotlib.pyplot as pl
45
 
46
  import torch
@@ -151,27 +139,40 @@ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=F
151
  return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
152
  transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
153
 
154
- # def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
155
- # cam2world = scene.get_im_poses().detach().cpu().numpy()
156
- # world2cam = inv(cam2world) #
157
- # principal_points = scene.get_principal_points().detach().cpu().numpy()
158
- # focals = scene.get_focals().detach().cpu().numpy()[..., None]
159
- # imgs = np.array(scene.imgs)
160
-
161
- # pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
162
- # pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d] #
163
-
164
- # masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
165
-
166
- # # move
167
- # mask_images = True
168
-
169
- # save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
170
- # save_images_masks(imgs, masks, images_path, masks_path, mask_images)
171
- # save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
172
- # save_imagestxt(world2cam, sparse_path)
173
- # save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
174
- # return save_path
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  @spaces.GPU(duration=300)
177
  def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
@@ -222,7 +223,7 @@ def get_reconstructed_scene(outdir, model, device, silent, image_size, current_s
222
  os.makedirs(base_colmapdata_dir, exist_ok=True)
223
  colmap_data_dir = get_next_dir(base_colmapdata_dir)
224
  #
225
- # save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
226
 
227
  if current_scene_state is not None and \
228
  current_scene_state.outfile_name is not None:
 
29
  from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
30
  from dust3r.demo import get_args_parser as dust3r_get_args_parser
31
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  import matplotlib.pyplot as pl
33
 
34
  import torch
 
139
  return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
140
  transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
141
 
142
+ def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
143
+ if 'save_pointcloud_with_normals' not in globals():
144
+ sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
145
+ sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/src'))
146
+ from colmap_dataset_utils import (
147
+ inv,
148
+ init_filestructure,
149
+ save_images_masks,
150
+ save_cameras,
151
+ save_imagestxt,
152
+ save_pointcloud,
153
+ save_pointcloud_with_normals
154
+ )
155
+
156
+ cam2world = scene.get_im_poses().detach().cpu().numpy()
157
+ world2cam = inv(cam2world) #
158
+ principal_points = scene.get_principal_points().detach().cpu().numpy()
159
+ focals = scene.get_focals().detach().cpu().numpy()[..., None]
160
+ imgs = np.array(scene.imgs)
161
+
162
+ pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
163
+ pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d] #
164
+
165
+ masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
166
+
167
+ # move
168
+ mask_images = True
169
+
170
+ save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
171
+ save_images_masks(imgs, masks, images_path, masks_path, mask_images)
172
+ save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
173
+ save_imagestxt(world2cam, sparse_path)
174
+ save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
175
+ return save_path
176
 
177
  @spaces.GPU(duration=300)
178
  def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
 
223
  os.makedirs(base_colmapdata_dir, exist_ok=True)
224
  colmap_data_dir = get_next_dir(base_colmapdata_dir)
225
  #
226
+ save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
227
 
228
  if current_scene_state is not None and \
229
  current_scene_state.outfile_name is not None: