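"""Hugging Face Space backend for single-image 3D reconstruction.

Wraps splatter-image's GaussianSplatPredictor: preprocesses one input photo,
predicts a Gaussian-splat reconstruction, and renders it into a looping MP4.
"""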
import sys
import os

# setup.sh builds/installs the splatter-image and diff-gaussian-rasterization
# dependencies so the sys.path entries below resolve.
os.system('bash setup.sh')
sys.path.append('/home/user/app/splatter-image')
sys.path.append('/home/user/app/diff-gaussian-rasterization')

import torch
import torchvision
import numpy as np
import imageio
from PIL import Image
import rembg
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from io import BytesIO
from utils.app_utils import (
    remove_background,
    resize_foreground,
    set_white_background,
    resize_to_128,
    to_tensor,
    get_source_camera_v2w_rmo_and_quats,
    get_target_cameras,
    export_to_obj
)
from scene.gaussian_predictor import GaussianSplatPredictor
from gaussian_renderer import render_predicted


class Image3DProcessor:
    def __init__(self, model_cfg_path, model_repo_id, model_filename):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        print("Image3DProcessor Device: ", self.device)

        # Load model configuration
        self.model_cfg = OmegaConf.load(model_cfg_path)

        # Download the pre-trained weights from the Hugging Face Hub
        model_path = hf_hub_download(repo_id=model_repo_id, filename=model_filename)
        self.model = GaussianSplatPredictor(self.model_cfg)
        ckpt_loaded = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(ckpt_loaded["model_state_dict"])
        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def preprocess(self, input_image, preprocess_background=True, foreground_ratio=0.65):
        # Create a new rembg session for background removal
        rembg_session = rembg.new_session()

        # Convert raw bytes to a PIL image if necessary
        if isinstance(input_image, bytes):
            input_image = Image.open(BytesIO(input_image))

        if preprocess_background:
            # Strip the background, recentre the object, composite onto white
            image = input_image.convert("RGB")
            image = remove_background(image, rembg_session)
            image = resize_foreground(image, foreground_ratio)
            image = set_white_background(image)
        else:
            image = input_image
            if image.mode == "RGBA":
                image = set_white_background(image)

        # The predictor expects 128x128 inputs
        image = resize_to_128(image)
        return image

    @torch.no_grad()
    def reconstruct_and_export(self, image):
        """
        Passes the image through the model, renders the reconstruction from a
        loop of target cameras, and returns the resulting MP4 as raw bytes.
        """
        image = np.array(image)
        image_tensor = to_tensor(image).to(self.device)
        view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
        view_to_world_source = view_to_world_source.to(self.device)
        rot_transform_quats = rot_transform_quats.to(self.device)

        # Predict an unactivated splat; the input is shaped (batch, views, C, H, W)
        reconstruction_unactivated = self.model(
            image_tensor.unsqueeze(0).unsqueeze(0),
            view_to_world_source,
            rot_transform_quats,
            None,
            activate_output=False
        )
        # Drop the batch dimension, then apply the scaling/opacity activations
        reconstruction = {k: v[0].contiguous() for k, v in reconstruction_unactivated.items()}
        reconstruction["scaling"] = self.model.scaling_activation(reconstruction["scaling"])
        reconstruction["opacity"] = self.model.opacity_activation(reconstruction["opacity"])
        # Render the reconstruction from each target camera in the loop
        world_view_transforms, full_proj_transforms, camera_centers = get_target_cameras()
        background = torch.tensor([1, 1, 1], dtype=torch.float32, device=self.device)
        loop_renders = []
        t_to_512 = torchvision.transforms.Resize(
            512, interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )
        for r_idx in range(world_view_transforms.shape[0]):
            rendered_image = render_predicted(
                reconstruction,
                world_view_transforms[r_idx].to(self.device),
                full_proj_transforms[r_idx].to(self.device),
                camera_centers[r_idx].to(self.device),
                background,
                self.model_cfg,
                focals_pixels=None
            )["render"]
            rendered_image = t_to_512(rendered_image)
            # Convert CHW float in [0, 1] to HWC uint8 for imageio
            loop_renders.append(
                torch.clamp(rendered_image * 255, 0.0, 255.0)
                .detach().permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            )

        # Save the loop video to a file and return its raw bytes
        video_path = "loop_.mp4"
        imageio.mimsave(video_path, loop_renders, fps=25)
        with open(video_path, "rb") as video_file:
            video_data = video_file.read()
        return video_data
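

# A minimal usage sketch. The config path, repo id, and checkpoint name below
# are illustrative assumptions, not values confirmed by this Space.
if __name__ == "__main__":
    processor = Image3DProcessor(
        model_cfg_path="gradio_config.yaml",             # hypothetical config path
        model_repo_id="szymanowiczs/splatter-image-v1",  # hypothetical repo id
        model_filename="model_latest.pth",               # hypothetical checkpoint
    )
    with open("input.png", "rb") as f:
        preprocessed = processor.preprocess(f.read())
    video_bytes = processor.reconstruct_and_export(preprocessed)
    with open("loop_render.mp4", "wb") as f:
        f.write(video_bytes)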