# InstantSplat / app.py
# Author: kairunwen — commit f908eb9: "update duration"
import os, subprocess, shlex, sys, gc
import time
import torch
import numpy as np
import shutil
import argparse
import gradio as gr
import uuid
import spaces
# Directory containing this file; used to locate the bundled wheels and submodules.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def _install_prebuilt_wheels():
    """Install the prebuilt extension wheels shipped in ``<BASE_DIR>/wheel``.

    ``sys.executable -m pip`` is used so the packages are installed into the
    interpreter that is running this app; a bare ``pip`` on PATH may belong to
    a different environment. Installation is best-effort, like the original
    unchecked calls: a failed install is reported but does not abort startup.
    """
    for wheel_name in (
        "diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
        "simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl",
        "curope-0.0.0-cp310-cp310-linux_x86_64.whl",
    ):
        # Resolve relative to this file rather than the current working directory.
        wheel_path = os.path.join(BASE_DIR, "wheel", wheel_name)
        result = subprocess.run([sys.executable, "-m", "pip", "install", wheel_path])
        if result.returncode != 0:
            print(f"WARNING: pip could not install {wheel_path} (exit code {result.returncode})")


_install_prebuilt_wheels()
# Make the vendored dust3r package importable as ``dust3r``.
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from train_joint import training
from render_by_interp import render_sets
GRADIO_CACHE_FOLDER = './gradio_cache_folder'
#############################################################################################################################################
def get_dust3r_args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--schedule", type=str, default='linear')
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--niter", type=int, default=300)
parser.add_argument("--focal_avg", type=bool, default=True)
parser.add_argument("--n_views", type=int, default=3)
parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
return parser
@spaces.GPU(duration=150)
def process(inputfiles, input_path=None):
if input_path is not None:
imgs_path = './assets/example/' + input_path
imgs_names = sorted(os.listdir(imgs_path))
inputfiles = []
for imgs_name in imgs_names:
file_path = os.path.join(imgs_path, imgs_name)
print(file_path)
inputfiles.append(file_path)
print(inputfiles)
# ------ (1) Coarse Geometric Initialization ------
# os.system(f"rm -rf {GRADIO_CACHE_FOLDER}")
parser = get_dust3r_args_parser()
opt = parser.parse_args()
tmp_user_folder = str(uuid.uuid4()).replace("-", "")
opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
img_folder_path = os.path.join(opt.img_base_path, "images")
img_folder_path = os.path.join(opt.img_base_path, "images")
model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
os.makedirs(img_folder_path, exist_ok=True)
opt.n_views = len(inputfiles)
if opt.n_views == 1:
raise gr.Error("The number of input images should be greater than 1.")
print("Multiple images: ", inputfiles)
for image_path in inputfiles:
if input_path is not None:
shutil.copy(image_path, img_folder_path)
else:
shutil.move(image_path, img_folder_path)
train_img_list = sorted(os.listdir(img_folder_path))
assert len(train_img_list)==opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"
images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
resolutions_are_equal = len(set(imgs_resolution)) == 1
if resolutions_are_equal == False:
raise gr.Error("The resolution of the input image should be the same.")
print("ori_size", ori_size)
start_time = time.time()
######################################################
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
output = inference(pairs, model, opt.device, batch_size=opt.batch_size)
output_colmap_path=img_folder_path.replace("images", "sparse/0")
os.makedirs(output_colmap_path, exist_ok=True)
scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
scene = scene.clean_pointcloud()
imgs = to_numpy(scene.imgs)
focals = scene.get_focals()
poses = to_numpy(scene.get_im_poses())
pts3d = to_numpy(scene.get_pts3d())
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
confidence_masks = to_numpy(scene.get_masks())
intrinsics = to_numpy(scene.get_intrinsics())
######################################################
end_time = time.time()
print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")
save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
### save VRAM
del scene
torch.cuda.empty_cache()
gc.collect()
##################################################################################################################################################
# ------ (2) Fast 3D-Gaussian Optimization ------
parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)
op = OptimizationParams(parser)
pp = PipelineParams(parser)
parser.add_argument('--debug_from', type=int, default=-1)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None)
parser.add_argument("--scene", type=str, default="demo")
parser.add_argument("--n_views", type=int, default=3)
parser.add_argument("--get_video", action="store_true")
parser.add_argument("--optim_pose", type=bool, default=True)
parser.add_argument("--skip_train", action="store_true")
parser.add_argument("--skip_test", action="store_true")
args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)
args.model_path = opt.img_base_path + '/output/'
args.source_path = opt.img_base_path
# args.model_path = GRADIO_CACHE_FOLDER + '/output/'
# args.source_path = GRADIO_CACHE_FOLDER
args.iteration = 1000
os.makedirs(args.model_path, exist_ok=True)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)
##################################################################################################################################################
# ------ (3) Render video by interpolation ------
parser = ArgumentParser(description="Testing script parameters")
model = ModelParams(parser, sentinel=True)
pipeline = PipelineParams(parser)
args.eval = True
args.get_video = True
args.n_views = opt.n_views
render_sets(
model.extract(args),
args.iteration,
pipeline.extract(args),
args.skip_train,
args.skip_test,
args,
)
output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'
# output_ply_path = GRADIO_CACHE_FOLDER+ f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
# output_video_path = GRADIO_CACHE_FOLDER+ f'/output/demo_{opt.n_views}_view.mp4'
return output_video_path, output_ply_path, output_ply_path
##################################################################################################################################################
_TITLE = '''InstantSplat'''
_DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center;">
<div style="width: 100%; text-align: center; font-size: 30px;">
<strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
</div>
</div>
<p></p>
<div align="center">
<a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>&nbsp;
<a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>&nbsp;
<a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
<a title="Social" href="https://x.com/KairunWen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</div>
<p></p>
* Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
* Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.
* Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).
'''
# <a style="display:inline-block" href="https://github.com/VITA-Group/LightGaussian"><img src="https://img.shields.io/badge/Source_Code-black?logo=Github" alt='Github Source Code'></a>&nbsp;
# &nbsp;
# <a style="display:inline-block" href="https://www.nvidia.com/en-us/"><img src="https://img.shields.io/badge/Nvidia-575757?logo=nvidia" alt='Nvidia'></a>
# * If InstantSplat is helpful, please give us a star ⭐ on Github. Thanks! <a style="display:inline-block; margin-left: .5em" href="https://github.com/VITA-Group/LightGaussian"><img src='https://img.shields.io/github/stars/VITA-Group/LightGaussian?style=social'/></a>
# block = gr.Blocks(title=_TITLE).queue()
# Build the Gradio UI: a description header, an input panel (uploads + RUN) and an
# output panel (3D model viewer, downloadable .ply, rendered video), then launch it.
block = gr.Blocks().queue()
with block:
    # Header row with the project description.
    with gr.Row():
        with gr.Column(scale=1):
            # gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)
    # Input panel.
    with gr.Row(variant='panel'):
        with gr.Tab("Input"):
            inputfiles = gr.File(file_count="multiple", label="images")
            # Hidden textbox that receives the example folder name from gr.Examples below.
            input_path = gr.Textbox(visible=False, label="example_path")
            button_gen = gr.Button("RUN")
    # Output panel.
    with gr.Row(variant='panel'):
        with gr.Tab("Output"):
            with gr.Column(scale=2):
                with gr.Group():
                    output_model = gr.Model3D(
                        label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],  # offset the camera slightly for a better view of the model
                    )
                    gr.Markdown(
                        """
<div class="model-description">
&nbsp;&nbsp;Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
</div>
"""
                    )
                output_file = gr.File(label="ply")
            with gr.Column(scale=1):
                output_video = gr.Video(label="video")
    # RUN processes the uploaded files; process() returns (video, ply, ply) in this output order.
    button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_model])
    # Examples feed a folder name (under ./assets/example) into the hidden textbox;
    # process() then loads that folder instead of uploads. cache_examples=True caches results.
    gr.Examples(
        examples=[
            "sora-santorini-3-views",
            # "TT-family-3-views",
            # "dl3dv-ba55-3-views",
        ],
        inputs=[input_path],
        outputs=[output_video, output_file, output_model],
        fn=lambda x: process(inputfiles=None, input_path=x),
        cache_examples=True,
        label='Sparse-view Examples'
    )
block.launch(server_name="0.0.0.0", share=False)