Abdulrahman1989 commited on
Commit
c35ebce
·
1 Parent(s): 82bebc2

Fix reconstruct_and_export

Browse files
Files changed (1) hide show
  1. Image3DProcessor.py +14 -16
Image3DProcessor.py CHANGED
@@ -66,48 +66,46 @@ class Image3DProcessor:
66
  return image
67
  @torch.no_grad()
68
  def reconstruct_and_export(self, image):
69
- # Ensure the input image is a NumPy array after preprocessing
70
- if isinstance(image, Image.Image):
71
- image = np.array(image)
72
- elif isinstance(image, bytes):
73
- image = np.array(Image.open(BytesIO(image)))
74
-
75
  image_tensor = to_tensor(image).to(self.device)
76
  view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
77
  view_to_world_source = view_to_world_source.to(self.device)
78
  rot_transform_quats = rot_transform_quats.to(self.device)
79
 
80
- reconstruction_unactivated = self.model(
81
  image_tensor.unsqueeze(0).unsqueeze(0),
82
  view_to_world_source,
83
  rot_transform_quats,
84
  None,
85
  activate_output=False
86
  )
87
-
88
  reconstruction = {k: v[0].contiguous() for k, v in reconstruction_unactivated.items()}
89
- reconstruction["scaling"] = self.model.scaling_activation(reconstruction["scaling"])
90
- reconstruction["opacity"] = self.model.opacity_activation(reconstruction["opacity"])
91
 
92
  # Render images in a loop
93
  world_view_transforms, full_proj_transforms, camera_centers = get_target_cameras()
94
  background = torch.tensor([1, 1, 1], dtype=torch.float32, device=self.device)
95
  loop_renders = []
96
- t_to_512 = torchvision.transforms.Resize(512, interpolation=torchvision.transforms.InterpolationMode.LANCZOS)
97
 
98
  for r_idx in range(world_view_transforms.shape[0]):
99
  rendered_image = render_predicted(
100
  reconstruction,
101
- world_view_transforms[r_idx].to(self.device),
102
- full_proj_transforms[r_idx].to(self.device),
103
- camera_centers[r_idx].to(self.device),
104
  background,
105
- self.model_cfg,
106
  focals_pixels=None
107
  )["render"]
108
  rendered_image = t_to_512(rendered_image)
109
  loop_renders.append(torch.clamp(rendered_image * 255, 0.0, 255.0).detach().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
110
-
111
  # Save video to a file and load its content
112
  video_path = "loop_.mp4"
113
  imageio.mimsave(video_path, loop_renders, fps=25)
 
66
  return image
67
  @torch.no_grad()
68
  def reconstruct_and_export(self, image):
69
+ """
70
+ Passes image through model and outputs the reconstruction.
71
+ """
72
+ image= np.array(image)
 
 
73
  image_tensor = to_tensor(image).to(self.device)
74
  view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
75
  view_to_world_source = view_to_world_source.to(self.device)
76
  rot_transform_quats = rot_transform_quats.to(self.device)
77
 
78
+ reconstruction_unactivated = model(
79
  image_tensor.unsqueeze(0).unsqueeze(0),
80
  view_to_world_source,
81
  rot_transform_quats,
82
  None,
83
  activate_output=False
84
  )
85
+
86
  reconstruction = {k: v[0].contiguous() for k, v in reconstruction_unactivated.items()}
87
+ reconstruction["scaling"] = model.scaling_activation(reconstruction["scaling"])
88
+ reconstruction["opacity"] = model.opacity_activation(reconstruction["opacity"])
89
 
90
  # Render images in a loop
91
  world_view_transforms, full_proj_transforms, camera_centers = get_target_cameras()
92
  background = torch.tensor([1, 1, 1], dtype=torch.float32, device=self.device)
93
  loop_renders = []
94
+ t_to_512 = torchvision.transforms.Resize(512, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
95
 
96
  for r_idx in range(world_view_transforms.shape[0]):
97
  rendered_image = render_predicted(
98
  reconstruction,
99
+ world_view_transforms[r_idx].to(device),
100
+ full_proj_transforms[r_idx].to(device),
101
+ camera_centers[r_idx].to(device),
102
  background,
103
+ model_cfg,
104
  focals_pixels=None
105
  )["render"]
106
  rendered_image = t_to_512(rendered_image)
107
  loop_renders.append(torch.clamp(rendered_image * 255, 0.0, 255.0).detach().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
108
+
109
  # Save video to a file and load its content
110
  video_path = "loop_.mp4"
111
  imageio.mimsave(video_path, loop_renders, fps=25)