Abdulrahman1989 committed
Commit 673639e · 1 Parent(s): b2ebb9a
Files changed (2)
  1. Image3DProcessor.py +12 -8
  2. app.py +5 -3
Image3DProcessor.py CHANGED
@@ -40,13 +40,13 @@ class Image3DProcessor:
         self.model.load_state_dict(ckpt_loaded["model_state_dict"])
         self.model.to(self.device)
         self.model.eval()
-
+
+    @torch.no_grad()
     def preprocess(self, input_image, preprocess_background=True, foreground_ratio=0.65):
-        # Convert bytes to PIL Image if necessary
-        if isinstance(input_image, bytes):
-            input_image = Image.open(BytesIO(input_image))
-
+        # Create a new Rembg session
         rembg_session = rembg.new_session()
+
+        # Preprocess input image
         if preprocess_background:
             image = input_image.convert("RGB")
             image = remove_background(image, rembg_session)
@@ -56,15 +56,19 @@ class Image3DProcessor:
             image = input_image
             if image.mode == "RGBA":
                 image = set_white_background(image)
+
         image = resize_to_128(image)
+
         return image
 
     @torch.no_grad()
     def reconstruct_and_export(self, image):
-        # Convert PIL Image to NumPy array if needed
+        # Ensure the input image is a NumPy array after preprocessing
         if isinstance(image, Image.Image):
             image = np.array(image)
-
+        elif isinstance(image, bytes):
+            image = np.array(Image.open(BytesIO(image)))
+
         image_tensor = to_tensor(image).to(self.device)
         view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
         view_to_world_source = view_to_world_source.to(self.device)
@@ -113,4 +117,4 @@ class Image3DProcessor:
         with open(mesh_path, "rb") as mesh_file:
             mesh_data = mesh_file.read()
 
-        return mesh_data, video_data
+        return mesh_data, video_data
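Taken together, these edits move input normalization out of preprocess and into reconstruct_and_export, which now accepts either a PIL Image or raw encoded bytes and converts both to a NumPy array before building the tensor. A minimal standalone sketch of just that branch, assuming only Pillow and NumPy (the function name normalize_image_input is hypothetical; the to_tensor and camera calls from the diff are omitted):

from io import BytesIO

import numpy as np
from PIL import Image

def normalize_image_input(image):
    # A PIL Image is converted to a NumPy array directly; raw encoded bytes
    # are decoded with Pillow first, mirroring the branch added in this commit.
    if isinstance(image, Image.Image):
        image = np.array(image)
    elif isinstance(image, bytes):
        image = np.array(Image.open(BytesIO(image)))
    return image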
app.py CHANGED
@@ -17,8 +17,10 @@ class VideoGenerator:
         self.processor = Image3DProcessor(model_cfg_path, model_repo_id, model_filename)
 
     def generate_3d_video(self, image):
-        # Process the image and create a 3D video and mesh
-        mesh_data, video_data = self.processor.reconstruct_and_export(image)
+        # Preprocess the image first
+        processed_image = self.processor.preprocess(image)
+        # Then pass it to reconstruct_and_export
+        mesh_data, video_data = self.processor.reconstruct_and_export(processed_image)
         return mesh_data, video_data
 
 class GradioApp:
@@ -64,4 +66,4 @@ class GradioApp:
 
 if __name__ == "__main__":
     app = GradioApp()
-    app.launch()
+    app.launch()
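With app.py updated, generate_3d_video runs the two-step flow: it first preprocesses the upload with Image3DProcessor.preprocess (which now expects a PIL Image, since the bytes handling was removed from it) and then hands the result to reconstruct_and_export. A hedged usage sketch under those assumptions; the constructor arguments and input path below are placeholders, not values from this commit:

from PIL import Image

from app import VideoGenerator

# Placeholder configuration; real values depend on the deployed model.
generator = VideoGenerator(
    model_cfg_path="config.yaml",
    model_repo_id="some-user/some-model",
    model_filename="model.pth",
)

# preprocess calls input_image.convert("RGB"), so a PIL Image is passed in.
input_image = Image.open("example.png")
mesh_data, video_data = generator.generate_3d_video(input_image)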