Abdulrahman1989 committed
Commit 6bc357d · 1 Parent(s): e5eeba4

Add Image3DProcessor

Files changed (1)
  1. Image3DProcessor.py +104 -0
Image3DProcessor.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ import torch
+ import torchvision
+ import numpy as np
+ import imageio
+ from PIL import Image
+ import rembg
+ from omegaconf import OmegaConf
+ from huggingface_hub import hf_hub_download
+ from utils.app_utils import (
+     remove_background,
+     resize_foreground,
+     set_white_background,
+     resize_to_128,
+     to_tensor,
+     get_source_camera_v2w_rmo_and_quats,
+     get_target_cameras,
+     export_to_obj
+ )
+ from scene.gaussian_predictor import GaussianSplatPredictor
+ from gaussian_renderer import render_predicted
+
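+ # Single-image 3D reconstruction: load a pretrained GaussianSplatPredictor,
+ # preprocess an input image, and render/export the predicted Gaussian splat.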
+ class Image3DProcessor:
+     def __init__(self, model_cfg_path, model_repo_id, model_filename):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load model configuration
+         self.model_cfg = OmegaConf.load(model_cfg_path)
+
+         # Load pre-trained model weights
+         model_path = hf_hub_download(repo_id=model_repo_id, filename=model_filename)
+         self.model = GaussianSplatPredictor(self.model_cfg)
+         ckpt_loaded = torch.load(model_path, map_location=self.device)
+         self.model.load_state_dict(ckpt_loaded["model_state_dict"])
+         self.model.to(self.device)
+         self.model.eval()
+
+     def preprocess(self, input_image, preprocess_background=True, foreground_ratio=0.65):
+         # Optionally strip the background and recentre the object on white;
+         # the image is always resized to the 128x128 input the model expects
+         rembg_session = rembg.new_session()
+         if preprocess_background:
+             image = input_image.convert("RGB")
+             image = remove_background(image, rembg_session)
+             image = resize_foreground(image, foreground_ratio)
+             image = set_white_background(image)
+         else:
+             image = input_image
+             if image.mode == "RGBA":
+                 image = set_white_background(image)
+         image = resize_to_128(image)
+         return image
+
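+     # Predict a Gaussian splat for the input view, then render a camera loop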
+     @torch.no_grad()
+     def reconstruct_and_export(self, image):
+         image_tensor = to_tensor(image).to(self.device)
+         view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
+         view_to_world_source = view_to_world_source.to(self.device)
+         rot_transform_quats = rot_transform_quats.to(self.device)
+
+         reconstruction_unactivated = self.model(
+             image_tensor.unsqueeze(0).unsqueeze(0),
+             view_to_world_source,
+             rot_transform_quats,
+             None,
+             activate_output=False
+         )
+
+         reconstruction = {k: v[0].contiguous() for k, v in reconstruction_unactivated.items()}
+         reconstruction["scaling"] = self.model.scaling_activation(reconstruction["scaling"])
+         reconstruction["opacity"] = self.model.opacity_activation(reconstruction["opacity"])
+
+         # Render images in a loop
+         world_view_transforms, full_proj_transforms, camera_centers = get_target_cameras()
+         background = torch.tensor([1, 1, 1], dtype=torch.float32, device=self.device)
+         loop_renders = []
+         t_to_512 = torchvision.transforms.Resize(512, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
+
+         for r_idx in range(world_view_transforms.shape[0]):
+             rendered_image = render_predicted(
+                 reconstruction,
+                 world_view_transforms[r_idx].to(self.device),
+                 full_proj_transforms[r_idx].to(self.device),
+                 camera_centers[r_idx].to(self.device),
+                 background,
+                 self.model_cfg,
+                 focals_pixels=None
+             )["render"]
+             rendered_image = t_to_512(rendered_image)
+             loop_renders.append(torch.clamp(rendered_image * 255, 0.0, 255.0).detach().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
+
+         # Write the video and mesh to a local tmp directory; the exact
+         # call pattern for export_to_obj is assumed from its import above
+         os.makedirs("tmp", exist_ok=True)
+         video_path = os.path.join("tmp", "loop.mp4")
+         imageio.mimsave(video_path, loop_renders, fps=25)
+         mesh_path = os.path.join("tmp", "mesh.obj")
+         export_to_obj(reconstruction_unactivated, mesh_path)
+
+         return mesh_path, video_path
+
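For reference, a minimal usage sketch of the new class; the config path, repo id, checkpoint filename and input image below are illustrative placeholders, not values from this commit:

    from PIL import Image
    from Image3DProcessor import Image3DProcessor

    # Placeholder config/checkpoint names: substitute the real ones
    processor = Image3DProcessor(
        model_cfg_path="model_config.yaml",
        model_repo_id="some-user/some-model-repo",
        model_filename="model_latest.pth",
    )

    # Preprocess a single RGB image, then reconstruct and export
    image = Image.open("input.png")
    image = processor.preprocess(image, preprocess_background=True)
    mesh_path, video_path = processor.reconstruct_and_export(image)
    print(mesh_path, video_path)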