import sys
import os
os.system('bash setup.sh')
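# Add the local splatter-image and diff-gaussian-rasterization checkouts to the import path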
sys.path.append('/home/user/app/splatter-image')
sys.path.append('/home/user/app/diff-gaussian-rasterization')
import torch
import torchvision
import numpy as np
import imageio
from PIL import Image
import rembg
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from io import BytesIO
from utils.app_utils import (
    remove_background,
    resize_foreground,
    set_white_background,
    resize_to_128,
    to_tensor,
    get_source_camera_v2w_rmo_and_quats,
    get_target_cameras,
    export_to_obj
)
from scene.gaussian_predictor import GaussianSplatPredictor
from gaussian_renderer import render_predicted

class Image3DProcessor:
    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        print("Image3DProcessor Device: ", self.device)
        # Load model configuration
        self.model_cfg = OmegaConf.load(model_cfg_path)
        
        # Download the pre-trained model weights from the Hugging Face Hub
        model_path = hf_hub_download(repo_id=model_repo_id, filename=model_filename)
        self.model = GaussianSplatPredictor(self.model_cfg)
        ckpt_loaded = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(ckpt_loaded["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def preprocess(self, input_image, preprocess_background=True, foreground_ratio=0.65):
        """
        Optionally removes the background, places the foreground on white,
        and resizes the image to the 128x128 input the model expects.
        """
        # Create a new Rembg session
        rembg_session = rembg.new_session()

        # Convert bytes to a PIL image if necessary
        if isinstance(input_image, bytes):
            input_image = Image.open(BytesIO(input_image))

        # Preprocess input image
        if preprocess_background:
            image = input_image.convert("RGB")
            image = remove_background(image, rembg_session)
            image = resize_foreground(image, foreground_ratio)
            image = set_white_background(image)
        else:
            image = input_image
            if image.mode == "RGBA":
                image = set_white_background(image)

        image = resize_to_128(image)

        return image

    @torch.no_grad()
    def reconstruct_and_export(self, image):
        """
        Passes image through model and outputs the reconstruction.
        """
        image= np.array(image)
        image_tensor = to_tensor(image).to(self.device)
        view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
        view_to_world_source = view_to_world_source.to(self.device)
        rot_transform_quats = rot_transform_quats.to(self.device)

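        # Predict an unactivated Gaussian splat reconstruction from the single input view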
        reconstruction_unactivated = self.model(
            image_tensor.unsqueeze(0).unsqueeze(0),
            view_to_world_source,
            rot_transform_quats,
            None,
            activate_output=False
        )

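        # Strip the batch dimension and apply the scaling/opacity activations deferred above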
        reconstruction = {k: v[0].contiguous() for k, v in reconstruction_unactivated.items()}
        reconstruction["scaling"] = self.model.scaling_activation(reconstruction["scaling"])
        reconstruction["opacity"] = self.model.opacity_activation(reconstruction["opacity"])

        # Render images in a loop
        world_view_transforms, full_proj_transforms, camera_centers = get_target_cameras()
        background = torch.tensor([1, 1, 1], dtype=torch.float32, device=self.device)
        loop_renders = []
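        # Upsample each 128x128 render to 512x512 with nearest-neighbor interpolation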
        t_to_512 = torchvision.transforms.Resize(512, interpolation=torchvision.transforms.InterpolationMode.NEAREST)

        for r_idx in range(world_view_transforms.shape[0]):
            rendered_image = render_predicted(
                reconstruction,
                world_view_transforms[r_idx].to(self.device),
                full_proj_transforms[r_idx].to(self.device),
                camera_centers[r_idx].to(self.device),
                background,
                self.model_cfg,
                focals_pixels=None
            )["render"]
            rendered_image = t_to_512(rendered_image)
            frame = torch.clamp(rendered_image * 255, 0.0, 255.0)
            loop_renders.append(frame.detach().permute(1, 2, 0).cpu().numpy().astype(np.uint8))

        # Save video to a file and load its content
        video_path = "loop_.mp4"
        imageio.mimsave(video_path, loop_renders, fps=25)
        with open(video_path, "rb") as video_file:
            video_data = video_file.read()

        return video_data
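

# Example usage (a minimal sketch): the config path, Hub repo id, checkpoint
# name, and file names below are hypothetical placeholders, not values that
# ship with this file.
if __name__ == "__main__":
    processor = Image3DProcessor(
        model_cfg_path="gradio_config.yaml",             # hypothetical config path
        model_repo_id="szymanowiczs/splatter-image-v1",  # hypothetical Hub repo id
        model_filename="model_latest.pth",               # hypothetical checkpoint name
    )
    with open("input.png", "rb") as f:                   # hypothetical input image
        preprocessed = processor.preprocess(f.read())
    video_bytes = processor.reconstruct_and_export(preprocessed)
    with open("output.mp4", "wb") as out:
        out.write(video_bytes)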