lemonaddie commited on
Commit
7807e94
·
verified ·
1 Parent(s): 9085e1a

Update models/depth_normal_pipeline_clip.py

Browse files
models/depth_normal_pipeline_clip.py CHANGED
@@ -79,7 +79,6 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
79
  match_input_res:bool =True,
80
  batch_size:int = 0,
81
  domain: str = "indoor",
82
- #seed: int = 0,
83
  color_map: str="Spectral",
84
  show_progress_bar:bool = True,
85
  ensemble_kwargs: Dict = None,
@@ -148,7 +147,6 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
148
  input_rgb=batched_image,
149
  num_inference_steps=denoising_steps,
150
  domain=domain,
151
- #seed=seed,
152
  show_pbar=show_progress_bar,
153
  )
154
  depth_pred_ls.append(depth_pred_raw.detach().clone())
@@ -232,7 +230,6 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
232
  def single_infer(self,input_rgb:torch.Tensor,
233
  num_inference_steps:int,
234
  domain:str,
235
- #seed: int,
236
  show_pbar:bool,):
237
 
238
  device = input_rgb.device
@@ -244,9 +241,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
244
  # encode image
245
  rgb_latent = self.encode_RGB(input_rgb)
246
 
247
- # Initial depth map (Guassian noise)
248
- #if seed >= 0:
249
- #torch.manual_seed(0)
250
  geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
251
  rgb_latent = rgb_latent.repeat(2,1,1,1)
252
 
@@ -258,7 +253,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
258
  (rgb_latent.shape[0], 1, 1)
259
  ) # [B, 1, 768]
260
 
261
- # hybrid hierarchical switcher
262
  geo_class = torch.tensor([[0., 1.], [1, 0]], device=device, dtype=self.dtype)
263
  geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)
264
 
 
79
  match_input_res:bool =True,
80
  batch_size:int = 0,
81
  domain: str = "indoor",
 
82
  color_map: str="Spectral",
83
  show_progress_bar:bool = True,
84
  ensemble_kwargs: Dict = None,
 
147
  input_rgb=batched_image,
148
  num_inference_steps=denoising_steps,
149
  domain=domain,
 
150
  show_pbar=show_progress_bar,
151
  )
152
  depth_pred_ls.append(depth_pred_raw.detach().clone())
 
230
  def single_infer(self,input_rgb:torch.Tensor,
231
  num_inference_steps:int,
232
  domain:str,
 
233
  show_pbar:bool,):
234
 
235
  device = input_rgb.device
 
241
  # encode image
242
  rgb_latent = self.encode_RGB(input_rgb)
243
 
244
+ # Initial geometric maps (Guassian noise)
 
 
245
  geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
246
  rgb_latent = rgb_latent.repeat(2,1,1,1)
247
 
 
253
  (rgb_latent.shape[0], 1, 1)
254
  ) # [B, 1, 768]
255
 
256
+ # hybrid switcher
257
  geo_class = torch.tensor([[0., 1.], [1, 0]], device=device, dtype=self.dtype)
258
  geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)
259