# Hugging Face Space status banner (captured with the source): Running on Zero
import spaces | |
import torch | |
print("cuda is available: ", torch.cuda.is_available()) | |
import gradio as gr | |
import os | |
import shutil | |
import numpy as np | |
import math | |
import open3d as o3d | |
from PIL import Image | |
import torchvision | |
import trimesh | |
import imageio | |
import matplotlib.pyplot as pl | |
pl.ion() | |
# Camera positional-encoding variant; the code below branches on "4DoF" / "6DoF".
CaPE_TYPE = "6DoF"
# Everything is placed on the GPU; the CPU fallback is left commented out.
device = "cuda"  # if torch.cuda.is_available() else "cpu"
# Half precision is used for the model weights and the tensors fed to them.
weight_dtype = torch.float16
# Allow TF32 matmuls (for GPUs >= Ampere and PyTorch >= 1.12).
torch.backends.cuda.matmul.allow_tf32 = True
# EscherNet | |
# create angles in archimedean spiral with N steps | |
def get_archimedean_spiral(sphere_radius, num_steps=250):
    """Sample points along a spherical spiral.

    See https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral"
    (c = a / pi).  The spiral parameter starts at 0.01 and is advanced by
    ``a / num_steps`` for as long as it stays below ``a = 40``.

    Args:
        sphere_radius: Radius of the sphere the points lie on.
        num_steps: Divisor of ``a`` that sets the parameter increment.

    Returns:
        ``(translations, angles)``: an ``(N, 3)`` array of points and an
        ``(N, 2)`` array holding ``[-parameter, theta]`` converted to degrees.
    """
    span = 40
    increment = span / (1 * num_steps)
    points = []
    angle_pairs = []
    t = 0.01
    while t < span:
        theta = t / span * math.pi
        # Spiral in the x-z plane around the upper y axis ...
        px = sphere_radius * math.sin(theta) * math.cos(-t)
        pz = sphere_radius * math.sin(-theta + math.pi) * math.sin(-t)
        py = sphere_radius * -math.cos(theta)
        # ... stored with the axes permuted to (x, z, -y).
        points.append((px, pz, -py))
        angle_pairs.append([np.rad2deg(-t), np.rad2deg(theta)])
        t += increment
    return np.array(points), np.stack(angle_pairs)
def look_at(origin, target, up):
    """Build a 4x4 matrix from a viewing direction and an up hint.

    The forward axis is the unit vector from ``origin`` to ``target``.  The
    first three columns of the result are ``right``, the re-orthogonalised
    ``up`` and ``-forward``; the fourth column is ``target`` (not ``origin``).

    Args:
        origin: 3-vector the direction starts from.
        target: 3-vector the direction points to; also used as translation.
        up: 3-vector up hint.  If it is parallel to the forward axis the
            cross product is zero and the result contains NaNs.

    Returns:
        A ``(4, 4)`` array whose last row is ``[0, 0, 0, 1]``.
    """
    forward = target - origin
    forward = forward / np.linalg.norm(forward)
    right = np.cross(up, forward)
    right = right / np.linalg.norm(right)
    new_up = np.cross(forward, right)
    top_rows = np.column_stack((right, new_up, -forward, target))
    # np.vstack is used instead of np.row_stack, which is a deprecated alias
    # as of NumPy 2.0.
    return np.vstack((top_rows, [0, 0, 0, 1]))
import einops | |
import sys | |
sys.path.insert(0, "./6DoF/") # TODO change it when deploying | |
# use the customized diffusers modules | |
from diffusers import DDIMScheduler | |
from dataset import get_pose | |
from CN_encoder import CN_encoder | |
from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline | |
from segment_anything import sam_model_registry, SamPredictor | |
# import rembg | |
from carvekit.api.high import HiInterface | |
# ---- EscherNet demo configuration -----------------------------------------
pretrained_model_name_or_path = "kxic/EscherNet_demo"
resolution = 256
h = w = resolution
guidance_scale = 3.0
radius = 2.2
bg_color = [1., 1., 1., 1.]

# Resize to (resolution, resolution), convert to a tensor, then normalise
# with mean 0.5 and std 0.5.
image_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((resolution, resolution)),  # 256, 256
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])

# Novel-view trajectory: sample the spiral with radius 1.5 and 200 steps, then
# keep only the first 100 samples of both returned arrays.
xyzs_spiral, angles_spiral = (
    samples[:100] for samples in get_archimedean_spiral(1.5, 200)
)

# ---- Pipeline --------------------------------------------------------------
scheduler = DDIMScheduler.from_pretrained(
    pretrained_model_name_or_path, subfolder="scheduler", revision=None
)
image_encoder = CN_encoder.from_pretrained(
    pretrained_model_name_or_path, subfolder="image_encoder", revision=None
)
# The pipeline is created without an image encoder; the CN_encoder loaded
# above is attached right afterwards, cast to the working dtype.
pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    revision=None,
    scheduler=scheduler,
    image_encoder=None,
    safety_checker=None,
    feature_extractor=None,
    torch_dtype=weight_dtype,
)
pipeline.image_encoder = image_encoder.to(weight_dtype)
pipeline.set_progress_bar_config(disable=False)
pipeline = pipeline.to(device)
pipeline.enable_vae_slicing()
#### object segmentation | |
def sam_init():
    """Create a ``SamPredictor`` for the SAM ViT-H model on ``device``.

    The checkpoint is expected at ``./sam_pt/sam_vit_h_4b8939.pth`` and is
    downloaded there first if it does not exist.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    sam_checkpoint = "./sam_pt/sam_vit_h_4b8939.pth"
    if not os.path.exists(sam_checkpoint):
        os.makedirs(os.path.dirname(sam_checkpoint), exist_ok=True)
        # Replaces a shell call to ``wget`` whose failures were ignored and
        # which could leave a partial file that passes the check above.
        # download_url_to_file raises on error.
        torch.hub.download_url_to_file(
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            sam_checkpoint,
        )
    model_type = "vit_h"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
    return SamPredictor(sam)
def create_carvekit_interface():
    """Construct the CarveKit ``HiInterface`` with this demo's settings.

    The interface is created on the CPU with fp16 disabled.  See the
    ``HiInterface`` doc strings for the meaning of each option.
    """
    settings = dict(
        object_type="object",  # Can be "object" or "hairs-like".
        batch_size_seg=6,
        batch_size_matting=1,
        device="cpu",
        seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=False,
    )
    return HiInterface(**settings)
# rembg_session = rembg.new_session() | |
# Background remover: built by create_carvekit_interface(), then both of its
# networks are moved to `device` and every `device` attribute is updated.
rembg_session = create_carvekit_interface()

rembg_session.u2net = rembg_session.u2net.to(device)
rembg_session.u2net.device = device

rembg_session.fba = rembg_session.fba.to(device)
rembg_session.fba.device = device

rembg_session.device = device

# SAM predictor, passed to load_images together with rembg_session.
predictor = sam_init()
def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
    """Generate novel views with EscherNet and write them to an MP4 file.

    Args:
        eschernet_input_dict: Dict with the keys ``'poses'``, ``'radii'`` and
            ``'imgs'``, as built by ``get_reconstructed_scene``.
            # assumes 'poses' are 4x4 matrices and 'imgs' is an (N, H, W, 4)
            # array with values in 0..255 -- TODO confirm against the producer
        sample_steps: Number of diffusion inference steps.
        sample_seed: Seed for the ``torch.Generator`` used by the pipeline.
        nvs_num: Number of novel views; must be between 1 and the number of
            spiral samples in ``xyzs_spiral``.
        nvs_mode: Not used; output poses are always taken from the spiral.

    Returns:
        Path of the written video, ``<tmpdirname>/eschernet/output.mp4``.

    Raises:
        gr.Error: If no input state is available yet or ``nvs_num`` is out of
            range.
    """
    if eschernet_input_dict is None:
        # Submit was pressed before pose estimation filled the state.
        raise gr.Error('Please run "Get Pose!" on your input images before generating novel views.')

    T_out = int(nvs_num)
    if not 1 <= T_out <= len(xyzs_spiral):
        # A larger value would make the slice step below zero.
        raise gr.Error(f"Number of novel views must be between 1 and {len(xyzs_spiral)}.")

    # set the random seed
    generator = torch.Generator(device=device).manual_seed(int(sample_seed))
    T_in = len(eschernet_input_dict['imgs'])

    ####### output pose
    # Every `stride`-th spiral sample is kept; only the first T_out of them
    # are used further down.
    stride = len(xyzs_spiral) // T_out
    xyzs = xyzs_spiral[::stride]
    angles_out = angles_spiral[::stride]

    ####### input's max radius for translation scaling
    max_t = np.max(eschernet_input_dict['radii'])

    ####### input pose
    pose_in = []
    for T_in_index in range(T_in):
        pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
        pose[1:3, :] *= -1  # coordinate system conversion
        pose[3, 3] *= 1. / max_t * radius  # scale radius to [1.5, 2.2]
        pose_in.append(torch.from_numpy(pose))

    ####### input image
    img = eschernet_input_dict['imgs'] / 255.
    # Pixels whose last channel is zero are replaced by the background colour.
    img[img[:, :, :, -1] == 0.] = bg_color
    # TODO batch image_transforms
    input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]

    ####### nvs pose
    pose_out = []
    for T_out_index in range(T_out):
        azimuth, polar = angles_out[T_out_index]
        if CaPE_TYPE == "4DoF":
            pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
        elif CaPE_TYPE == "6DoF":
            pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
            pose = np.linalg.inv(pose)
            pose[2, :] *= -1
            pose_out.append(torch.from_numpy(get_pose(pose)))

    # [B, T, C, H, W]
    input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
    # [B, T, 4]
    pose_in = np.stack(pose_in)
    pose_out = np.stack(pose_out)

    # NOTE(review): the inverse poses are only defined for "6DoF"; the pipeline
    # call below would raise NameError for any other CaPE_TYPE.
    if CaPE_TYPE == "6DoF":
        pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
        pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
        pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
        pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)

    pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
    pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)

    input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
    assert T_in == input_image.shape[0]
    assert T_in == pose_in.shape[1]
    assert T_out == pose_out.shape[1]

    # run inference
    pipeline.enable_xformers_memory_efficient_attention()
    image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
                     poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
                     height=h, width=w, T_in=T_in, T_out=T_out,
                     guidance_scale=guidance_scale,
                     # Honour the requested step count instead of a fixed 50.
                     num_inference_steps=int(sample_steps),
                     generator=generator,
                     output_type="numpy").images

    # save output image: start from an empty output directory
    output_dir = os.path.join(tmpdirname, "eschernet")
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # get a video
    frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
    video_path = os.path.join(output_dir, "output.mp4")
    imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
    return video_path
############################ Dust3r as Pose Estimation ############################ | |
from scipy.spatial.transform import Rotation | |
import copy | |
from dust3r.inference import inference | |
from dust3r.model import AsymmetricCroCo3DStereo | |
from dust3r.image_pairs import make_pairs | |
from dust3r.utils.image import load_images, rgb | |
from dust3r.utils.device import to_numpy | |
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes | |
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode | |
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False, same_focals=False):
    """Assemble a trimesh scene from the reconstruction and export it as GLB.

    The scene contains a coordinate axis, the geometry (one merged point cloud
    or one merged mesh) and one camera per entry of ``cams2world``.  Before
    export the scene is transformed by the inverse of
    ``cams2world[0] @ OPENGL @ rot``, where ``rot`` is the 'y' Euler rotation
    built below.

    Args:
        outdir: Directory that receives ``scene.glb``.
        imgs, pts3d, mask: Per-view colours, 3D points and validity masks.
        focals: Per-camera focals (only ``focals[0]`` is used when
            ``same_focals`` is set).
        cams2world: Per-camera poses.
        cam_size: ``screen_width`` passed to ``add_scene_cam``.
        cam_color: A list of per-camera colours, a single colour, or ``None``
            to cycle through ``CAM_COLORS``.
        as_pointcloud: Export a point cloud instead of a mesh.
        transparent_cams: Do not attach the image to the camera.
        silent: Suppress the export message.
        same_focals: Use ``focals[0]`` for every camera.

    Returns:
        Path of the exported file.
    """
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
    if not same_focals:
        assert (len(cams2world) == len(focals))

    pts3d, imgs, focals, cams2world = (
        to_numpy(item) for item in (pts3d, imgs, focals, cams2world)
    )

    scene = trimesh.Scene()

    # add axes
    scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))

    # geometry: full pointcloud, or one mesh per view merged into a single mesh
    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        geometry = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
    else:
        meshes = [pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]) for i in range(len(imgs))]
        geometry = trimesh.Trimesh(**cat_meshes(meshes))
    scene.add_geometry(geometry)

    # add each camera
    colors_per_camera = isinstance(cam_color, list)
    for idx, pose_c2w in enumerate(cams2world):
        if colors_per_camera:
            edge_color = cam_color[idx]
        else:
            edge_color = cam_color or CAM_COLORS[idx % len(CAM_COLORS)]
        focal = focals[0] if same_focals else focals[idx]
        cam_image = None if transparent_cams else imgs[idx]
        add_scene_cam(scene, pose_c2w, edge_color,
                      cam_image, focal,
                      imsize=imgs[idx].shape[1::-1], screen_width=cam_size)

    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))

    outfile = os.path.join(outdir, 'scene.glb')
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile
def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
    """Extract a 3D model (GLB file) from a reconstructed scene.

    Poses and points are read from ``scene.vis_poses`` / ``scene.vis_pts3d``
    (attributes attached by ``get_reconstructed_scene``) rather than from the
    scene's own getters.

    Returns:
        Path of the exported GLB file, or ``None`` when ``scene`` is ``None``.
    """
    if scene is None:
        return None

    # optional post-processing
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # optimised values from the scene
    view_rgb = to_numpy(scene.imgs)
    view_focals = to_numpy(scene.get_focals().cpu())

    # TODO use the vis_poses / vis_pts3d
    poses = scene.vis_poses
    points = scene.vis_pts3d

    # The threshold has to be set before the masks are queried.
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    view_masks = to_numpy(scene.get_masks())

    return _convert_scene_output_to_glb(
        outdir, view_rgb, points, view_masks, view_focals, poses,
        as_pointcloud=as_pointcloud,
        transparent_cams=transparent_cams,
        cam_size=cam_size,
        silent=silent,
        same_focals=same_focals,
    )
def _rotation_onto_z(normal):
    """Return the 3x3 rotation that maps the unit vector ``normal`` onto +z."""
    z_axis = np.array([0., 0., 1.])
    axis = np.cross(normal, z_axis)
    sin_angle = np.linalg.norm(axis)
    cos_angle = float(np.clip(np.dot(normal, z_axis), -1.0, 1.0))
    if sin_angle < 1e-8:
        # Already aligned -> identity; exactly opposite -> half turn about x.
        if cos_angle > 0:
            return np.eye(3)
        return Rotation.from_rotvec(np.pi * np.array([1., 0., 0.])).as_matrix()
    # A rotation vector is (unit axis) * angle, so the cross product has to be
    # normalised; its raw length is sin(angle), which under-rotates.
    angle = np.arctan2(sin_angle, cos_angle)
    return Rotation.from_rotvec(angle * axis / sin_angle).as_matrix()


def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid, same_focals):
    """Run DUSt3R on a list of images and prepare the EscherNet inputs.

    Steps: load the images with background removal, run pairwise inference
    and global alignment, move the result into an object-centred frame
    (origin at the centroid of the confident object points, RANSAC plane
    normal rotated onto the z axis, then a half turn about x), and export the
    scene with ``get_3D_model_from_scene``.

    Args:
        filelist: Paths of the uploaded images.
        schedule, niter: Passed to ``compute_global_alignment`` (only used
            when more than two views are aligned).
        min_conf_thr: Confidence threshold used for the point masks.
        as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size:
            Forwarded to ``get_3D_model_from_scene``.
        scenegraph_type, winsize, refid: Pairing strategy; ``"swin"`` and
            ``"oneref"`` get ``winsize`` / ``refid`` appended.
        same_focals: Forwarded to the aligner and the exporter.

    Returns:
        ``(outfile, imgs, eschernet_input)``: the GLB path, a list that
        alternates the RGB and RGBA version of every view, and a dict with
        the keys ``"poses"``, ``"radii"`` (one distance per camera) and
        ``"imgs"``.

    Raises:
        gr.Error: If no image was provided or no confident object point
            remains to centre the scene on.
    """
    if not filelist:
        raise gr.Error("Please upload at least one image first.")

    silent = False
    image_size = 224

    # remove the directory if it already exists
    outdir = tmpdirname
    if os.path.exists(outdir):
        shutil.rmtree(outdir)
    os.makedirs(outdir, exist_ok=True)

    imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True,
                                  rembg_session=rembg_session, predictor=predictor)
    if len(imgs) == 1:
        # A single image is duplicated so that a pair can be formed.
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1

    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=1, verbose=not silent)

    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
    lr = 0.01
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

    # for eschernet
    cams2world = to_numpy(scene.get_im_poses().cpu())
    rgbimg = to_numpy(scene.imgs)

    imgs = []
    rgbaimg = []
    for i, rgb_view in enumerate(rgbimg):  # when only 1 image, scene.imgs is two
        # With a single upload both scene views share the only RGBA image.
        rgba_view = imgs_rgba[0] if (len(imgs_rgba) == 1 and i == 1) else imgs_rgba[i]
        imgs.append(rgb_view)
        imgs.append(rgba_view)
        rgbaimg.append(np.array(rgba_view))
    rgbaimg = np.array(rgbaimg)

    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())
    obj_mask = rgbaimg[..., 3] > 0

    # Global coordinate system: origin at the object centre, z axis up.
    pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
    pts_obj = np.concatenate([p[m & obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
    if pts_obj.shape[0] == 0:
        raise gr.Error("No confident object points were found; please try different images.")
    centroid = np.mean(pts_obj, axis=0)  # obj center
    obj2world = np.eye(4)
    obj2world[:3, 3] = -centroid  # T_wc

    # Up direction: normal of the dominant plane found by RANSAC.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts)
    plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
    normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
    # Pick the sign of the normal from its y component.
    if normal[1] < 0:
        normal = -normal

    # rotate the normal onto the z axis
    z_up = np.eye(4)
    z_up[:3, :3] = _rotation_onto_z(normal)
    obj2world = z_up @ obj2world

    # flip 180
    flip_rot = np.array([[1, 0, 0, 0],
                         [0, -1, 0, 0],
                         [0, 0, -1, 0],
                         [0, 0, 0, 1]])
    obj2world = flip_rot @ obj2world

    # camera poses and points expressed in the new coordinate system
    cams2world = np.array([obj2world @ cam2world for cam2world in cams2world])
    for i, view_pts in enumerate(pts3d):
        pts3d[i] = view_pts @ obj2world[:3, :3].T + obj2world[:3, 3]

    # TODO rewrite hack: the exporter reads these two attributes
    scene.vis_poses = cams2world.copy()
    scene.vis_pts3d = pts3d.copy()

    # Distance of every camera from the object centre.  axis=-1 keeps one
    # value per camera; without it all translations collapse into a single
    # Frobenius norm that grows with the number of views.
    radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3], axis=-1)
    np.save(os.path.join(outdir, "radii.npy"), radii)

    eschernet_input = {"poses": cams2world,
                       "radii": radii,
                       "imgs": rgbaimg}
    print("got eschernet input")

    outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size, same_focals=same_focals)
    return outfile, imgs, eschernet_input
def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    """Rebuild the window-size and reference-id sliders for the current files.

    The incoming ``winsize`` and ``refid`` values are ignored; both sliders
    are recreated.  The window-size slider is visible only for ``"swin"`` and
    the reference-id slider only for ``"oneref"``.

    Returns:
        ``(winsize, refid)`` as two new ``gr.Slider`` components.
    """
    num_files = 1 if inputfiles is None else len(inputfiles)
    max_winsize = max(1, math.ceil((num_files - 1) / 2))

    show_winsize = scenegraph_type == "swin"
    show_refid = scenegraph_type == "oneref"

    winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
                        minimum=1, maximum=max_winsize, step=1, visible=show_winsize)
    refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
                      maximum=num_files - 1, step=1, visible=show_refid)
    return winsize, refid
def get_examples(path):
    """Collect example image sets from the sub-directories of ``path``.

    Each sub-directory is one example; its entries are listed in sorted
    order.  Entries of ``path`` that are not directories (stray files such as
    a README) are skipped instead of raising ``NotADirectoryError``.

    Args:
        path: Directory that contains one sub-directory per example.

    Returns:
        A list with one ``[file_paths]`` item per example, in sorted
        sub-directory order.
    """
    objs = []
    for obj_name in sorted(os.listdir(path)):
        obj_dir = os.path.join(path, obj_name)
        if not os.path.isdir(obj_dir):
            continue
        img_files = [os.path.join(obj_dir, img_file) for img_file in sorted(os.listdir(obj_dir))]
        objs.append([img_files])
    print("objs = ", objs)
    return objs
def preview_input(inputfiles):
    """Read every uploaded file with ``pl.imread`` for the preview gallery.

    Returns:
        ``None`` when ``inputfiles`` is ``None``; otherwise the list of loaded
        images, in input order.
    """
    if inputfiles is None:
        return None
    return [pl.imread(img_file) for img_file in inputfiles]
# def main(): | |
# dustr init | |
# ---- DUSt3R initialisation -------------------------------------------------
silent = False
image_size = 224
weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)

# DUSt3R writes the 3D model inside tmpdirname.  A fixed directory is used
# (instead of a tempfile.TemporaryDirectory) and emptied at start-up.
tmpdirname = os.path.join('logs/user_object')
if os.path.exists(tmpdirname):
    shutil.rmtree(tmpdirname)
os.makedirs(tmpdirname, exist_ok=True)

if not silent:
    print('Outputing stuff in', tmpdirname)
_HEADER_ = ''' | |
<h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2> | |
<b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis. | |
Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>. | |
<a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b> | |
<a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b> | |
<a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a> | |
<h4><b>Tips:</b></h4> | |
- Our model can take <b>any number input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results. | |
- Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity. | |
- The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system. | |
- The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap! | |
''' | |
_CITE_ = r""" | |
📝 <b>Citation</b>: | |
```bibtex | |
@article{kong2024eschernet, | |
title={EscherNet: A Generative Model for Scalable View Synthesis}, | |
author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J}, | |
journal={arXiv preprint arXiv:2402.03908}, | |
year={2024} | |
} | |
``` | |
""" | |
# NOTE(review): the nesting below is reconstructed from the component order;
# confirm it against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)

    # State that carries the output of get_reconstructed_scene to run_eschernet.
    eschernet_input = gr.State(None)

    with gr.Row(variant="panel"):
        # ---- left column: upload, pose estimation, preview, examples ----
        with gr.Column():
            with gr.Row():
                input_image = gr.File(file_count="multiple")
            with gr.Row():
                run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
            with gr.Row():
                processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
            with gr.Row(variant="panel"):
                # example image sets found under the "examples" folder
                gr.Examples(
                    examples=get_examples('examples'),
                    inputs=[input_image],
                    label="Examples (click one set of images to start!)",
                    examples_per_page=20,
                )

        # ---- right column: 3D preview, sampling options, result video ----
        with gr.Column():
            with gr.Row():
                outmodel = gr.Model3D()
            with gr.Row():
                gr.Markdown(
                    "\n<h4><b>Check if the pose (blue is axis is estimated z-up direction) and segmentation looks correct. If not, remove the incorrect images and try again.</b></h4>\n"
                )
            with gr.Row():
                with gr.Group():
                    # This checkbox is not wired to any event below.
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=True
                    )
                    sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
                    sample_steps = gr.Slider(
                        label="Sample Steps",
                        minimum=30,
                        maximum=75,
                        value=50,
                        step=5,
                        visible=False,
                    )
                    nvs_num = gr.Slider(
                        label="Number of Novel Views",
                        minimum=5,
                        maximum=100,
                        value=30,
                        step=1,
                    )
                    nvs_mode = gr.Dropdown(
                        ["archimedes circle"],  # "fixed 4 views", "fixed 8 views"
                        value="archimedes circle",
                        label="Novel Views Pose Chosen",
                        visible=True,
                    )
            with gr.Row():
                gr.Markdown(
                    "\n<h4><b>Choose your desired novel view poses number and generate! The more output images the longer it takes.</b></h4>\n"
                )
            with gr.Row():
                submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
            with gr.Row():
                with gr.Column():
                    output_video = gr.Video(
                        label="video", format="mp4",
                        width=379,
                        autoplay=True,
                        interactive=False,
                    )
            with gr.Row():
                gr.Markdown(
                    "\n<h4><b>The novel views are generated on an archimedean spiral (rotating around z-up axis and looking at the object center). You can download the video.</b></h4>\n"
                )

    gr.Markdown(_CITE_)

    # DUSt3R parameters: created invisible to keep the page clean, but still
    # passed to get_reconstructed_scene with the defaults set here.
    with gr.Column():
        with gr.Row():
            schedule = gr.Dropdown(
                ["linear", "cosine"],
                value='linear', label="schedule", info="For global alignment!", visible=False,
            )
            niter = gr.Number(
                value=300, precision=0, minimum=0, maximum=5000,
                label="num_iterations", info="For global alignment!", visible=False,
            )
            scenegraph_type = gr.Dropdown(
                ["complete", "swin", "oneref"],
                value='complete', label="Scenegraph",
                info="Define how to make pairs",
                interactive=True, visible=False,
            )
            same_focals = gr.Checkbox(
                value=True, label="Focal", info="Use the same focal for all cameras", visible=False,
            )
            winsize = gr.Slider(
                label="Scene Graph: Window Size", value=1,
                minimum=1, maximum=1, step=1, visible=False,
            )
            refid = gr.Slider(
                label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False,
            )
        with gr.Row():
            # confidence threshold
            min_conf_thr = gr.Slider(
                label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False,
            )
            # camera size in the exported scene
            cam_size = gr.Slider(
                label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False,
            )
        with gr.Row():
            as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
            # two post-processing options
            mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
            clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
            transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)

    # ---- events (registered in the same order as before) ----
    # New upload: resize the scene-graph sliders to the number of files.
    input_image.change(
        set_scenegraph_options,
        inputs=[input_image, winsize, refid, scenegraph_type],
        outputs=[winsize, refid],
    )
    # "Get Pose!": run DUSt3R and fill the 3D view, the gallery and the state.
    run_dust3r.click(
        fn=get_reconstructed_scene,
        inputs=[
            input_image, schedule, niter, min_conf_thr, as_pointcloud,
            mask_sky, clean_depth, transparent_cams, cam_size,
            scenegraph_type, winsize, refid, same_focals,
        ],
        outputs=[outmodel, processed_image, eschernet_input],
    )
    # New upload: show the raw images in the gallery.
    input_image.change(
        fn=preview_input,
        inputs=[input_image],
        outputs=[processed_image],
    )
    # "Submit": generate the novel views and show the video.
    submit.click(
        fn=run_eschernet,
        inputs=[eschernet_input, sample_steps, sample_seed, nvs_num, nvs_mode],
        outputs=[output_video],
    )

demo.queue(max_size=10).launch(share=True)