ostapagon commited on
Commit
03269be
·
1 Parent(s): 5f75751

Check allocation problem

Browse files
Files changed (1) hide show
  1. demo/gs_train.py +11 -4
demo/gs_train.py CHANGED
@@ -7,8 +7,11 @@ from tqdm.auto import tqdm
7
  import gradio as gr
8
  import importlib.util
9
  from dataclasses import dataclass, field
 
10
 
11
  import spaces
 
 
12
 
13
 
14
  @dataclass
@@ -45,7 +48,7 @@ class ModelParams:
45
  images: str = "images"
46
  resolution: int = -1
47
  white_background: bool = True
48
- data_device: str = "cuda"
49
  eval: bool = False
50
 
51
  @dataclass
@@ -60,7 +63,7 @@ class TrainingArgs:
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
- @spaces.GPU(duration=90)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
@@ -138,6 +141,10 @@ def train(
138
  checkpoint_iterations = args.checkpoint_iterations
139
  debug_from = args.debug_from
140
 
 
 
 
 
141
  tb_writer = prepare_output_and_logger(dataset)
142
 
143
  gaussians = GaussianModel(dataset.sh_degree)
@@ -145,7 +152,7 @@ def train(
145
  gaussians.training_setup(opt)
146
 
147
  bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
148
- background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
149
 
150
  iter_start = torch.cuda.Event(enable_timing = True)
151
  iter_end = torch.cuda.Event(enable_timing = True)
@@ -174,7 +181,7 @@ def train(
174
  # Render
175
  if (iteration - 1) == debug_from:
176
  pipe.debug = True
177
- bg = torch.rand((3), device="cuda") if opt.random_background else background
178
 
179
  render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
180
  image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
 
7
  import gradio as gr
8
  import importlib.util
9
  from dataclasses import dataclass, field
10
+ from demo_globals import DEVICE
11
 
12
  import spaces
13
+ from simple_knn._C import distCUDA2
14
+
15
 
16
 
17
  @dataclass
 
48
  images: str = "images"
49
  resolution: int = -1
50
  white_background: bool = True
51
+ data_device: str = DEVICE
52
  eval: bool = False
53
 
54
  @dataclass
 
63
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
64
  start_checkpoint: str = None
65
 
66
+ @spaces.GPU(duration=20)
67
  def train(
68
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
69
  convert_SHs_python, compute_cov3D_python, debug,
 
141
  checkpoint_iterations = args.checkpoint_iterations
142
  debug_from = args.debug_from
143
 
144
+ dist2 = torch.clamp_min(distCUDA2(torch.randn((90804, 3)).float().cuda()), 0.0000001)
145
+ print("dist2.shape: ", dist2.shape)
146
+
147
+
148
  tb_writer = prepare_output_and_logger(dataset)
149
 
150
  gaussians = GaussianModel(dataset.sh_degree)
 
152
  gaussians.training_setup(opt)
153
 
154
  bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
155
+ background = torch.tensor(bg_color, dtype=torch.float32, device=DEVICE)
156
 
157
  iter_start = torch.cuda.Event(enable_timing = True)
158
  iter_end = torch.cuda.Event(enable_timing = True)
 
181
  # Render
182
  if (iteration - 1) == debug_from:
183
  pipe.debug = True
184
+ bg = torch.rand((3), device=DEVICE) if opt.random_background else background
185
 
186
  render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
187
  image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]