Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,021 Bytes
35e2073 f908eb9 35e2073 551f29f 35e2073 4a2c55f 35e2073 551f29f 35e2073 e5ccf22 35e2073 509f048 4aa11ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
import os, subprocess, shlex, sys, gc
import time
import torch
import numpy as np
import shutil
import argparse
import gradio as gr
import uuid
import spaces
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def _install_bundled_wheels():
    """Best-effort install of the prebuilt extension wheels shipped in ``<BASE_DIR>/wheel``.

    Judging by their file names, the wheels are built for CPython 3.10 on
    linux_x86_64. As before, a failed install does not stop start-up; it is only
    reported on stderr, so packages installed by other means can still be used.
    """
    for wheel_name in (
        "diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
        "simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl",
        "curope-0.0.0-cp310-cp310-linux_x86_64.whl",
    ):
        # Anchor to the script directory so the app also works when started from another cwd.
        wheel_path = os.path.join(BASE_DIR, "wheel", wheel_name)
        # `sys.executable -m pip` installs into the interpreter running this app,
        # not into whichever `pip` happens to be first on PATH.
        result = subprocess.run([sys.executable, "-m", "pip", "install", wheel_path])
        if result.returncode != 0:
            print(
                f"WARNING: pip could not install {wheel_path} (exit code {result.returncode})",
                file=sys.stderr,
            )


_install_bundled_wheels()
# Make the vendored DUSt3R checkout importable as the top-level `dust3r` package
# (the `from dust3r... import` lines below rely on this).
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.device import to_numpy
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from train_joint import training
from render_by_interp import render_sets
GRADIO_CACHE_FOLDER = './gradio_cache_folder'
#############################################################################################################################################


def _str2bool(value):
    """Convert a command-line string such as "true", "False", "1" or "0" to a bool.

    ``type=bool`` cannot be used for this. argparse would call ``bool("False")``,
    which is ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_dust3r_args_parser():
    """Build the argument parser for the DUSt3R coarse-initialization stage.

    Options cover the checkpoint path, image size, device, inference batch size, and
    the global-alignment settings (schedule, lr, niter, focal_avg). They also cover
    the number of views and the base folder (default GRADIO_CACHE_FOLDER) under which
    per-run folders are created.

    Returns:
        argparse.ArgumentParser: a fresh, unparsed parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
    parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--schedule", type=str, default='linear')
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--niter", type=int, default=300)
    # Explicit str->bool conversion so that `--focal_avg False` really disables it.
    parser.add_argument("--focal_avg", type=_str2bool, default=True)
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
    return parser
@spaces.GPU(duration=150)
def process(inputfiles, input_path=None):
    """Run the InstantSplat demo pipeline on one set of images.

    Each call works in a fresh folder ``<base_path>/<uuid hex>`` and runs:

    1. Coarse geometric initialization. DUSt3R inference runs on all image pairs,
       followed by global alignment. The results are exported to
       ``<run>/sparse/0``: ``cameras.txt``, ``images.txt``, ``points3D.ply``
       (confidence-masked points), ``pts_4_3dgs_all.npy`` and ``focal.npy``.
    2. Gaussian optimization via ``training``, with output under ``<run>/output/``.
    3. Interpolated rendering via ``render_sets`` with ``get_video`` enabled.

    Args:
        inputfiles: uploaded images, either plain paths or objects exposing the path
            as ``.name``. The latter is assumed to be a gradio tempfile wrapper;
            confirm for the installed gradio version. The files are moved into the
            run folder. This argument is ignored when ``input_path`` is given.
        input_path: optional name of a folder under ``./assets/example/``. All of
            its files are used and copied (not moved), so the bundled examples are
            preserved.

    Returns:
        ``(video_path, ply_path, ply_path)``. The Gaussian point cloud is returned
        twice, once for the file download and once for the 3D viewer. These are the
        paths the training/rendering stages are expected to write; they are not
        checked here.

    Raises:
        gr.Error: fewer than two images, images of differing resolution, or a
            file-count mismatch after staging the images.
    """

    def _str_to_bool(value):
        # argparse's `type=bool` would turn the string "False" into True.
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # ------ (0) Resolve and validate the input images ------
    if input_path is not None:
        # Example run: use every file of the bundled example folder, in sorted order.
        imgs_path = './assets/example/' + input_path
        inputfiles = []
        for imgs_name in sorted(os.listdir(imgs_path)):
            file_path = os.path.join(imgs_path, imgs_name)
            print(file_path)
            inputfiles.append(file_path)
        print(inputfiles)
    # Clicking RUN without uploading yields None/empty; reject it together with the
    # single-image case before doing any expensive work (model loading, folders).
    if not inputfiles or len(inputfiles) < 2:
        raise gr.Error("The number of input images should be greater than 1.")
    inputfiles = [f if isinstance(f, (str, os.PathLike)) else f.name for f in inputfiles]

    # ------ (1) Coarse Geometric Initialization ------
    parser = get_dust3r_args_parser()
    # sys.argv is shared with the training parser below, so each parser must
    # tolerate flags that only the other one defines.
    opt, _ = parser.parse_known_args()
    opt.img_base_path = os.path.join(opt.base_path, uuid.uuid4().hex)
    img_folder_path = os.path.join(opt.img_base_path, "images")
    os.makedirs(img_folder_path, exist_ok=True)
    opt.n_views = len(inputfiles)
    print("Multiple images: ", inputfiles)
    for image_path in inputfiles:
        if input_path is not None:
            shutil.copy(image_path, img_folder_path)
        else:
            shutil.move(image_path, img_folder_path)
    train_img_list = sorted(os.listdir(img_folder_path))
    if len(train_img_list) != opt.n_views:
        raise gr.Error(f"Number of images in the folder is not equal to {opt.n_views}")
    # NOTE(review): the load size is fixed at 512; --image_size is not consulted here.
    images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
    if len(set(imgs_resolution)) != 1:
        raise gr.Error("The resolution of the input image should be the same.")
    print("ori_size", ori_size)

    dust3r_model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
    start_time = time.time()
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    output = inference(pairs, dust3r_model, opt.device, batch_size=opt.batch_size)
    # The network is not needed after inference; drop our references so its
    # weights can be released before the Gaussian training stage.
    del dust3r_model, pairs
    # Built from path components instead of str.replace("images", ...), which would
    # also rewrite any "images" substring inside --base_path.
    output_colmap_path = os.path.join(opt.img_base_path, "sparse", "0")
    os.makedirs(output_colmap_path, exist_ok=True)
    scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
    compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
    scene = scene.clean_pointcloud()
    imgs = to_numpy(scene.imgs)
    focals = scene.get_focals()
    poses = to_numpy(scene.get_im_poses())
    pts3d = to_numpy(scene.get_pts3d())
    # The threshold must be set before get_masks(), which the masks below depend on.
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
    confidence_masks = to_numpy(scene.get_masks())
    intrinsics = to_numpy(scene.get_intrinsics())
    end_time = time.time()
    print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")

    save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
    save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
    # Keep only the points (and their colors) selected by the confidence masks.
    pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
    color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
    color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
    storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
    # Unmasked points are saved as well, flattened to rows of 3 values.
    pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
    np.save(os.path.join(output_colmap_path, "pts_4_3dgs_all.npy"), pts_4_3dgs_all)
    np.save(os.path.join(output_colmap_path, "focal.npy"), np.array(focals.cpu()))

    # Save VRAM: release the DUSt3R alignment state, collect garbage first, then
    # return cached CUDA blocks.
    del scene, output, focals
    gc.collect()
    torch.cuda.empty_cache()

    # ------ (2) Fast 3D-Gaussian Optimization ------
    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default = None)
    parser.add_argument("--scene", type=str, default="demo")
    parser.add_argument("--n_views", type=int, default=3)
    parser.add_argument("--get_video", action="store_true")
    parser.add_argument("--optim_pose", type=_str_to_bool, default=True)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    args, _ = parser.parse_known_args(sys.argv[1:])
    args.save_iterations.append(args.iterations)
    args.model_path = opt.img_base_path + '/output/'
    args.source_path = opt.img_base_path
    # Training gets the same `args`; use the real number of views instead of the
    # CLI default (stage 3 below already does this).
    args.n_views = opt.n_views
    # NOTE(review): rendering and the returned .ply path use iteration 1000, while
    # save_iterations receives args.iterations -- confirm these always agree.
    args.iteration = 1000
    os.makedirs(args.model_path, exist_ok=True)
    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)

    # ------ (3) Render video by interpolation ------
    parser = ArgumentParser(description="Testing script parameters")
    model_params = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    args.eval = True
    args.get_video = True
    args.n_views = opt.n_views
    render_sets(
        model_params.extract(args),
        args.iteration,
        pipeline.extract(args),
        args.skip_train,
        args.skip_test,
        args,
    )
    output_ply_path = os.path.join(opt.img_base_path, "output", "point_cloud", f"iteration_{args.iteration}", "point_cloud.ply")
    output_video_path = os.path.join(opt.img_base_path, "output", f"demo_{opt.n_views}_view.mp4")
    return output_video_path, output_ply_path, output_ply_path
##################################################################################################################################################
# Page title. In this file it is only referenced from commented-out lines.
_TITLE = '''InstantSplat'''
# Markdown/HTML header shown at the top of the demo page (rendered with gr.Markdown in the UI).
_DESCRIPTION = '''
<div style="display: flex; justify-content: center; align-items: center;">
<div style="width: 100%; text-align: center; font-size: 30px;">
<strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
</div>
</div>
<p></p>
<div align="center">
<a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>
<a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>
<a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
<a title="Social" href="https://x.com/KairunWen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</div>
<p></p>
* Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
* Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.
* Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).
'''
# Disabled header snippets and an alternative Blocks constructor, kept for reference:
# <a style="display:inline-block" href="https://github.com/VITA-Group/LightGaussian"><img src="https://img.shields.io/badge/Source_Code-black?logo=Github" alt='Github Source Code'></a>
#
# <a style="display:inline-block" href="https://www.nvidia.com/en-us/"><img src="https://img.shields.io/badge/Nvidia-575757?logo=nvidia" alt='Nvidia'></a>
# * If InstantSplat is helpful, please give us a star ⭐ on Github. Thanks! <a style="display:inline-block; margin-left: .5em" href="https://github.com/VITA-Group/LightGaussian"><img src='https://img.shields.io/github/stars/VITA-Group/LightGaussian?style=social'/></a>
# block = gr.Blocks(title=_TITLE).queue()
# Build the Gradio page; .queue() routes requests through Gradio's queue, which suits
# the long-running GPU job in `process`.
block = gr.Blocks().queue()
with block:
    # Header: project description.
    with gr.Row():
        with gr.Column(scale=1):
            # gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)
    # Input panel: multi-image upload, plus a hidden textbox that only carries the
    # example folder name used by gr.Examples below.
    with gr.Row(variant='panel'):
        with gr.Tab("Input"):
            inputfiles = gr.File(file_count="multiple", label="images")
            input_path = gr.Textbox(visible=False, label="example_path")
            button_gen = gr.Button("RUN")
    # Output panel: interactive 3D viewer and .ply download, plus the rendered video.
    with gr.Row(variant='panel'):
        with gr.Tab("Output"):
            with gr.Column(scale=2):
                with gr.Group():
                    output_model = gr.Model3D(
                        label="3D Dense Model under Gaussian Splats Formats, need more time to visualize",
                        interactive=False,
                        camera_position=[0.5, 0.5, 1],  # offset the camera slightly for a better view of the model
                    )
                    gr.Markdown(
                        """
<div class="model-description">
Use the left mouse button to rotate, the scroll wheel to zoom, and the right mouse button to move.
</div>
"""
                    )
                output_file = gr.File(label="ply")
            with gr.Column(scale=1):
                output_video = gr.Video(label="video")
    # The output order matches process()'s return value: (video path, ply path, ply path).
    button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_model])
    # Bundled examples are passed to process() as a folder name under ./assets/example/.
    # cache_examples=True asks Gradio to cache these results, which means process()
    # is run for each listed example when the cache is built.
    gr.Examples(
        examples=[
            "sora-santorini-3-views",
            # "TT-family-3-views",
            # "dl3dv-ba55-3-views",
        ],
        inputs=[input_path],
        outputs=[output_video, output_file, output_model],
        fn=lambda x: process(inputfiles=None, input_path=x),
        cache_examples=True,
        label='Sparse-view Examples'
    )
# Serve on all network interfaces without creating a public share link.
block.launch(server_name="0.0.0.0", share=False)