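# Hugging Face Space app code: reconstructs a 3D Gaussian splat from a single
# input image and renders it as a 360-degree turntable video.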
import sys
import os

# Run the setup script before importing the repos it provides
os.system('bash setup.sh')

sys.path.append('/home/user/app/splatter-image')
sys.path.append('/home/user/app/diff-gaussian-rasterization')

import torch
import torchvision
import numpy as np
import imageio
from PIL import Image
import rembg
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from io import BytesIO

from utils.app_utils import (
    remove_background,
    resize_foreground,
    set_white_background,
    resize_to_128,
    to_tensor,
    get_source_camera_v2w_rmo_and_quats,
    get_target_cameras,
    export_to_obj
)

from scene.gaussian_predictor import GaussianSplatPredictor
from gaussian_renderer import render_predicted
class Image3DProcessor:
    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        print("Image3DProcessor Device: ", self.device)

        # Load model configuration
        self.model_cfg = OmegaConf.load(model_cfg_path)
        # Fetch the pre-trained model weights from the Hugging Face Hub
        model_path = hf_hub_download(repo_id=model_repo_id, filename=model_filename)
        self.model = GaussianSplatPredictor(self.model_cfg)
        ckpt_loaded = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(ckpt_loaded["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()
    def preprocess(self, input_image, preprocess_background=True, foreground_ratio=0.65):
        # Create a new Rembg session
        rembg_session = rembg.new_session()

        # Convert bytes to a PIL image if necessary
        if isinstance(input_image, bytes):
            input_image = Image.open(BytesIO(input_image))

        # Preprocess input image
        if preprocess_background:
            image = input_image.convert("RGB")
            image = remove_background(image, rembg_session)
            image = resize_foreground(image, foreground_ratio)
            image = set_white_background(image)
        else:
            image = input_image
            if image.mode == "RGBA":
                image = set_white_background(image)

        image = resize_to_128(image)
        return image
    @torch.no_grad()  # inference only; no gradients are needed
    def reconstruct_and_export(self, image):
        """
        Passes the image through the model and renders the reconstruction
        as a 360-degree loop, returned as MP4 bytes.
        """
        image = np.array(image)
        image_tensor = to_tensor(image).to(self.device)
        view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
        view_to_world_source = view_to_world_source.to(self.device)
        rot_transform_quats = rot_transform_quats.to(self.device)

        # Predict Gaussian parameters from the single source view, leaving
        # scaling and opacity unactivated
        reconstruction_unactivated = self.model(
            image_tensor.unsqueeze(0).unsqueeze(0),
            view_to_world_source,
            rot_transform_quats,
            None,
            activate_output=False
        )

        # Drop the batch dimension and apply the deferred activations
        reconstruction = {k: v[0].contiguous() for k, v in reconstruction_unactivated.items()}
        reconstruction["scaling"] = self.model.scaling_activation(reconstruction["scaling"])
        reconstruction["opacity"] = self.model.opacity_activation(reconstruction["opacity"])

        # Render images in a loop over the target cameras
        world_view_transforms, full_proj_transforms, camera_centers = get_target_cameras()
        background = torch.tensor([1, 1, 1], dtype=torch.float32, device=self.device)
        loop_renders = []
        t_to_512 = torchvision.transforms.Resize(512, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
        for r_idx in range(world_view_transforms.shape[0]):
            rendered_image = render_predicted(
                reconstruction,
                world_view_transforms[r_idx].to(self.device),
                full_proj_transforms[r_idx].to(self.device),
                camera_centers[r_idx].to(self.device),
                background,
                self.model_cfg,
                focals_pixels=None
            )["render"]
            rendered_image = t_to_512(rendered_image)
            loop_renders.append(
                torch.clamp(rendered_image * 255, 0.0, 255.0)
                .detach().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            )

        # Save video to a file and load its content
        video_path = "loop_.mp4"
        imageio.mimsave(video_path, loop_renders, fps=25)
        with open(video_path, "rb") as video_file:
            video_data = video_file.read()
        return video_data
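
# A minimal usage sketch (an assumption, not part of the original app): the
# config path, Hub repo id, checkpoint filename, and input image below are
# hypothetical placeholders.
if __name__ == "__main__":
    processor = Image3DProcessor(
        model_cfg_path="gradio_config.yaml",           # hypothetical config path
        model_repo_id="your-org/your-splatter-model",  # hypothetical Hub repo
        model_filename="model_latest.pth",             # hypothetical checkpoint
    )
    with open("input.png", "rb") as f:                 # hypothetical input image
        image = processor.preprocess(f.read())
    video_bytes = processor.reconstruct_and_export(image)
    with open("loop_render.mp4", "wb") as out:
        out.write(video_bytes)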