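"""Interactive novel-view synthesis demo.

Loads a trained neural point cloud and refinement model, splats the points at
multiple scales for a user-controlled orbit camera, and displays the network
output alongside a PCA visualization of the splatted features
(interactive_pipe / gradio GUI).
"""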
import os
import sys

# Make the local `src` package importable when running this script directly
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
sys.path.append(src_path)
from pixrender.rendering.splatting import ms_splatting
from interactive_pipe import interactive
from pathlib import Path
from pixrender.camera.camera_geometry import get_camera_extrinsics, get_camera_intrinsics, set_camera_parameters_orbit_mode
import torch
from pixrender.rendering.splatting import splat_points
from pixrender.interactive.interactive_plugins import define_default_sliders
# from novel_views import load_colored_point_cloud_from_files
from pixrender.learning.utils import load_model
from pixrender.rendering.forward_project import project_3d_to_2d
from interactive_pipe import interactive_pipeline
from pixrender.interactive.utils import tensor_to_image, rescale_image
from pixrender.properties import DEVICE, SCALE_LIST
from shared_parser import get_shared_parser
from experiments_definition import get_experiment_from_id
from pixrender.learning.experiments import get_training_content
from typing import Optional
def infer_image(splatted_image, model, global_params={}) -> torch.Tensor:
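    """Run the refinement model on the multi-scale splatted image.

    Returns a (H, W, C) tensor for the scale selected in `global_params`.
    If `model` is None, the splatted input is returned unchanged.
    """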
if model is not None:
        with torch.no_grad():
            # Each scale goes from (H, W, C) to a (1, C, H, W) batch for the model
            ms_list = [spl.permute(2, 0, 1).unsqueeze(0)
                       for spl in splatted_image]
            inferred_image = model(ms_list)[global_params.get("scale", 0)]
            inferred_image = inferred_image.squeeze(0).permute(1, 2, 0)
return inferred_image
else:
return splatted_image
def apply_pca_to_tensor(tensor, n_components=3):
"""
Applies PCA on a tensor of shape (H, W, C) to reduce it to (H, W, n_components).
Parameters:
- tensor: Input tensor of shape (H, W, C)
- n_components: Number of principal components to keep
Returns:
- pca_tensor: Tensor of shape (H, W, n_components) after PCA
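    Example (illustrative, on a random feature map):
        >>> feats = torch.randn(64, 64, 32)
        >>> apply_pca_to_tensor(feats).shape
        torch.Size([64, 64, 3])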
"""
# Validate inputs
if tensor.dim() != 3 or tensor.size(2) < n_components:
raise ValueError(
"Input tensor must be of shape (H, W, C) with C >= n_components.")
H, W, C = tensor.shape
# Flatten the (H, W) dimensions
flat_tensor = tensor.reshape(-1, C)
    # Standardize the features (epsilon guards against zero-variance channels)
    mean = flat_tensor.mean(dim=0)
    std = flat_tensor.std(dim=0)
    standardized_tensor = (flat_tensor - mean) / (std + 1e-8)
try:
# Perform SVD, which is equivalent to PCA since the data is centered
U, S, V = torch.svd(standardized_tensor)
    except Exception as e:
        # SVD may fail to converge on degenerate inputs: fall back to the first channels
        print(e)
        return tensor[:, :, :n_components]
# Keep the top n_components
principal_components = V[:, :n_components]
# Project the data onto the top n_components
pca_result = torch.mm(standardized_tensor, principal_components)
# Reshape back to (H, W, n_components)
pca_tensor = pca_result.reshape(H, W, n_components)
return pca_tensor
@interactive(neural_point_cloud_pca_flag=(True,))
def debug_splat(img_ms, neural_point_cloud_pca_flag=True, global_params={}):
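    """Visualize the splatted neural point cloud at the selected scale.

    When the PCA flag is enabled, the per-pixel features are reduced to 3
    channels with `apply_pca_to_tensor`; otherwise the first 3 channels are shown.
    """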
    selected_scale = global_params.get('scale', 0)
    if neural_point_cloud_pca_flag:
        selected_full_res = apply_pca_to_tensor(img_ms[selected_scale])
        # Fall back to the raw first 3 channels if PCA could not be applied
        if selected_full_res.shape != img_ms[selected_scale][:, :, :3].shape:
            selected_full_res = img_ms[selected_scale][:, :, :3]
    else:
        selected_full_res = img_ms[selected_scale][:, :, :3]
return rescale_image(tensor_to_image(selected_full_res).clip(0, 1), global_params=global_params)
def splat_pipeline_novel_view(wc_points, wc_normals, points_colors, model, scales_list):
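    """Full novel-view pipeline: camera setup, 3D->2D projection, multi-scale
    splatting, neural inference, and a debug view of the splatted features."""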
# yaw, pitch, roll, cam_pos = set_camera_parameters()
yaw, pitch, roll, cam_pos = set_camera_parameters_orbit_mode()
camera_extrinsics = get_camera_extrinsics(yaw, pitch, roll, cam_pos)
camera_intrinsics, w, h = get_camera_intrinsics()
cc_points, points_depths, cc_normals = project_3d_to_2d(
wc_points, camera_intrinsics, camera_extrinsics, wc_normals)
# # Let's splat the triangle nodes
# splatted_image = splat_points(cc_points, points_colors, points_depths, w, h, camera_intrinsics, cc_normals)
splatted_image = ms_splatting(cc_points, points_colors, points_depths, w,
h, camera_intrinsics, cc_normals, scales_list)
inferred_image = infer_image(splatted_image, model)
inferred_image = tensor_to_image(inferred_image)
inferred_image = rescale_image(inferred_image)
splatted_image_debug = debug_splat(splatted_image)
return inferred_image, splatted_image_debug
def main_interactive_version(exp, training_dir):
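    """Load a trained experiment and launch the interactive gradio viewer.

    `exp` selects the experiment configuration; `training_dir` is the folder
    containing the saved checkpoint and colored point cloud.
    """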
config = get_experiment_from_id(exp)
model, optim = get_training_content(config, training_mode=False)
# model_path = training_dir / f"__{exp:04d}" / "best_model.pt"
model_path = training_dir / f"__{exp:04d}" / "last_model.pt"
# wc_points, wc_normals, color_pred = load_colored_point_cloud_from_files(splat_scene_path)
model_state_dict, wc_points, wc_normals, color_pred = load_model(
model_path)
wc_points = wc_points.detach().to(DEVICE)
wc_normals = wc_normals.detach().to(DEVICE)
color_pred = color_pred.detach().to(DEVICE)
define_default_sliders(orbit_mode=True, multiscale=model.n_scales)
if model_path is not None:
print(model.count_parameters())
model.load_state_dict(model_state_dict)
model.to(DEVICE)
model.eval()
else:
model = None
interactive_pipeline(
gui="gradio",
cache=True,
safe_input_buffer_deepcopy=False,
size=(20, 15)
)(splat_pipeline_novel_view)(wc_points, wc_normals, color_pred, model, config[SCALE_LIST])
if __name__ == '__main__':
main_interactive_version(55, training_dir=Path("pretrained_scene"))