Spaces:
Running
on
Zero
Running
on
Zero
Add lazy loading for cuda libs
Browse files- demo/gs_train.py +26 -21
- demo/mast3r_demo.py +35 -34
demo/gs_train.py
CHANGED
@@ -6,28 +6,10 @@ import uuid
|
|
6 |
from tqdm.auto import tqdm
|
7 |
import gradio as gr
|
8 |
import importlib.util
|
|
|
9 |
|
10 |
-
|
11 |
-
gaussian_splatting_path = 'wild-gaussian-splatting/gaussian-splatting/'
|
12 |
-
sys.path.append(gaussian_splatting_path)
|
13 |
-
|
14 |
-
# Import necessary modules from the gaussian-splatting directory
|
15 |
-
from utils.loss_utils import l1_loss, ssim
|
16 |
-
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
17 |
-
from scene import Scene, GaussianModel
|
18 |
-
from utils.general_utils import safe_state
|
19 |
-
from utils.image_utils import psnr
|
20 |
-
|
21 |
-
# Dynamically import the train module from the gaussian-splatting directory
|
22 |
-
train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
|
23 |
-
gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
|
24 |
-
train_spec.loader.exec_module(gaussian_splatting_train)
|
25 |
-
|
26 |
-
# Import the necessary functions from the dynamically loaded module
|
27 |
-
prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
|
28 |
-
training_report = gaussian_splatting_train.training_report
|
29 |
|
30 |
-
from dataclasses import dataclass, field
|
31 |
|
32 |
@dataclass
|
33 |
class PipelineParams:
|
@@ -78,6 +60,7 @@ class TrainingArgs:
|
|
78 |
checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
|
79 |
start_checkpoint: str = None
|
80 |
|
|
|
81 |
def train(
|
82 |
data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
|
83 |
convert_SHs_python, compute_cov3D_python, debug,
|
@@ -86,6 +69,29 @@ def train(
|
|
86 |
percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
|
87 |
densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
|
88 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
print(data_source_path)
|
90 |
# Create instances of the parameter dataclasses
|
91 |
dataset = ModelParams(
|
@@ -221,7 +227,6 @@ def train(
|
|
221 |
|
222 |
|
223 |
from os import makedirs
|
224 |
-
from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
|
225 |
import torchvision
|
226 |
import subprocess
|
227 |
|
|
|
6 |
from tqdm.auto import tqdm
|
7 |
import gradio as gr
|
8 |
import importlib.util
|
9 |
+
from dataclasses import dataclass, field
|
10 |
|
11 |
+
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
|
|
13 |
|
14 |
@dataclass
|
15 |
class PipelineParams:
|
|
|
60 |
checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
|
61 |
start_checkpoint: str = None
|
62 |
|
63 |
+
@spaces.GPU(duration=600)
|
64 |
def train(
|
65 |
data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
|
66 |
convert_SHs_python, compute_cov3D_python, debug,
|
|
|
69 |
percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
|
70 |
densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
|
71 |
):
|
72 |
+
|
73 |
+
# Add the path to the gaussian-splatting repository
|
74 |
+
if 'GaussianRasterizer' not in globals():
|
75 |
+
gaussian_splatting_path = 'wild-gaussian-splatting/gaussian-splatting/'
|
76 |
+
sys.path.append(gaussian_splatting_path)
|
77 |
+
|
78 |
+
# Import necessary modules from the gaussian-splatting directory
|
79 |
+
from utils.loss_utils import l1_loss, ssim
|
80 |
+
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
81 |
+
from scene import Scene, GaussianModel
|
82 |
+
from utils.general_utils import safe_state
|
83 |
+
from utils.image_utils import psnr
|
84 |
+
from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix
|
85 |
+
|
86 |
+
# Dynamically import the train module from the gaussian-splatting directory
|
87 |
+
train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
|
88 |
+
gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
|
89 |
+
train_spec.loader.exec_module(gaussian_splatting_train)
|
90 |
+
|
91 |
+
# Import the necessary functions from the dynamically loaded module
|
92 |
+
prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
|
93 |
+
training_report = gaussian_splatting_train.training_report
|
94 |
+
|
95 |
print(data_source_path)
|
96 |
# Create instances of the parameter dataclasses
|
97 |
dataset = ModelParams(
|
|
|
227 |
|
228 |
|
229 |
from os import makedirs
|
|
|
230 |
import torchvision
|
231 |
import subprocess
|
232 |
|
demo/mast3r_demo.py
CHANGED
@@ -29,18 +29,6 @@ from dust3r.utils.device import to_numpy
|
|
29 |
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
30 |
from dust3r.demo import get_args_parser as dust3r_get_args_parser
|
31 |
|
32 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
|
33 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/src'))
|
34 |
-
# from colmap_dataset_utils import (
|
35 |
-
# inv,
|
36 |
-
# init_filestructure,
|
37 |
-
# save_images_masks,
|
38 |
-
# save_cameras,
|
39 |
-
# save_imagestxt,
|
40 |
-
# save_pointcloud,
|
41 |
-
# save_pointcloud_with_normals
|
42 |
-
# )
|
43 |
-
|
44 |
import matplotlib.pyplot as pl
|
45 |
|
46 |
import torch
|
@@ -151,27 +139,40 @@ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=F
|
|
151 |
return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
152 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
@spaces.GPU(duration=300)
|
177 |
def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
|
@@ -222,7 +223,7 @@ def get_reconstructed_scene(outdir, model, device, silent, image_size, current_s
|
|
222 |
os.makedirs(base_colmapdata_dir, exist_ok=True)
|
223 |
colmap_data_dir = get_next_dir(base_colmapdata_dir)
|
224 |
#
|
225 |
-
|
226 |
|
227 |
if current_scene_state is not None and \
|
228 |
current_scene_state.outfile_name is not None:
|
|
|
29 |
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
30 |
from dust3r.demo import get_args_parser as dust3r_get_args_parser
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
import matplotlib.pyplot as pl
|
33 |
|
34 |
import torch
|
|
|
139 |
return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
140 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
141 |
|
142 |
+
def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
|
143 |
+
if 'save_pointcloud_with_normals' not in globals():
|
144 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
|
145 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/src'))
|
146 |
+
from colmap_dataset_utils import (
|
147 |
+
inv,
|
148 |
+
init_filestructure,
|
149 |
+
save_images_masks,
|
150 |
+
save_cameras,
|
151 |
+
save_imagestxt,
|
152 |
+
save_pointcloud,
|
153 |
+
save_pointcloud_with_normals
|
154 |
+
)
|
155 |
+
|
156 |
+
cam2world = scene.get_im_poses().detach().cpu().numpy()
|
157 |
+
world2cam = inv(cam2world) #
|
158 |
+
principal_points = scene.get_principal_points().detach().cpu().numpy()
|
159 |
+
focals = scene.get_focals().detach().cpu().numpy()[..., None]
|
160 |
+
imgs = np.array(scene.imgs)
|
161 |
+
|
162 |
+
pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
|
163 |
+
pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d] #
|
164 |
+
|
165 |
+
masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
|
166 |
+
|
167 |
+
# move
|
168 |
+
mask_images = True
|
169 |
+
|
170 |
+
save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
|
171 |
+
save_images_masks(imgs, masks, images_path, masks_path, mask_images)
|
172 |
+
save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
|
173 |
+
save_imagestxt(world2cam, sparse_path)
|
174 |
+
save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
|
175 |
+
return save_path
|
176 |
|
177 |
@spaces.GPU(duration=300)
|
178 |
def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
|
|
|
223 |
os.makedirs(base_colmapdata_dir, exist_ok=True)
|
224 |
colmap_data_dir = get_next_dir(base_colmapdata_dir)
|
225 |
#
|
226 |
+
save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
|
227 |
|
228 |
if current_scene_state is not None and \
|
229 |
current_scene_state.outfile_name is not None:
|