meow
init
d6d3a5b
raw
history blame
8.67 kB
import os
import os.path as op
import re
from abc import abstractmethod

import matplotlib
import matplotlib.cm as cm
import numpy as np
from aitviewer.headless import HeadlessRenderer
from aitviewer.renderables.billboard import Billboard
from aitviewer.renderables.meshes import Meshes
from aitviewer.scene.camera import OpenCVCamera
from aitviewer.scene.material import Material
from aitviewer.utils.so3 import aa2rot_numpy
from aitviewer.viewer import Viewer
from easydict import EasyDict as edict
from loguru import logger
from PIL import Image
from tqdm import tqdm
# Integer labels used for each entity in rendered segmentation masks.
OBJ_ID, SMPLX_ID, LEFT_ID, RIGHT_ID = 100, 150, 200, 250

# Scene-node name -> segmentation label.
SEGM_IDS = dict(object=OBJ_ID, smplx=SMPLX_ID, left=LEFT_ID, right=RIGHT_ID)
# Shared "plasma" colormap.
# ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9, so the
# ``matplotlib.colormaps`` registry is tried first; the legacy accessor is only used on
# releases that do not provide the registry.
try:
    cmap = matplotlib.colormaps["plasma"]
except AttributeError:
    cmap = cm.get_cmap("plasma")
# Named surface materials; construct_viewer_meshes picks one via each mesh's "color" entry.
# "none" maps to None, i.e. no Material object is passed for that mesh.
# NOTE(review): if these 4-tuples are RGBA, "green" is actually pure red and
# "cyan"/"cyan-light" are green-dominant. Values are kept as-is because existing renders
# may depend on them -- confirm before renaming or recoloring.
_MATERIAL_AMBIENT = 0.2
_MATERIAL_COLORS = {
    "white": (1.0, 1.0, 1.0, 1.0),
    "red": (0.969, 0.106, 0.059, 1.0),
    "blue": (0.0, 0.0, 1.0, 1.0),
    "green": (1.0, 0.0, 0.0, 1.0),
    "cyan": (0.051, 0.659, 0.051, 1.0),
    "light-blue": (0.588, 0.5647, 0.9725, 1.0),
    "cyan-light": (0.051, 0.659, 0.051, 1.0),
    "dark-light": (0.404, 0.278, 0.278, 1.0),
    "rice": (0.922, 0.922, 0.102, 1.0),
}
materials = {"none": None}
materials.update(
    (_name, Material(color=_rgba, ambient=_MATERIAL_AMBIENT))
    for _name, _rgba in _MATERIAL_COLORS.items()
)
class ViewerData(edict):
    """
    Interface to standardize viewer data.

    Holds the camera extrinsics ``Rt``, the intrinsics ``K``, the image size
    (``cols``, ``rows``) and, optionally, one image path per frame. ``num_frames`` is
    taken from the first dimension of ``Rt``. The layout is checked on construction.
    """

    def __init__(self, Rt, K, cols, rows, imgnames=None):
        """
        Store the fields and validate them.

        Args:
            Rt: Array-like with a ``shape`` of (num_frames, 3, 4).
            K: Array-like with a ``shape`` of (3, 3).
            cols: Stored as given; not validated here.
            rows: Stored as given; not validated here.
            imgnames: Optional sequence with one image path per frame.

        Raises:
            AssertionError: If any check in ``validate_format`` fails.
        """
        self.imgnames = imgnames
        self.Rt = Rt
        self.K = K
        self.num_frames = Rt.shape[0]
        self.cols = cols
        self.rows = rows
        self.validate_format()

    def validate_format(self):
        """Assert that ``Rt``, ``K`` and (when given) ``imgnames`` have the expected layout."""
        # Extrinsics: exactly three dimensions laid out as (num_frames, 3, 4).
        rt_shape = self.Rt.shape
        assert len(rt_shape) == 3
        rt_frames, rt_rows, rt_cols = rt_shape
        assert rt_frames == self.num_frames
        assert rt_rows == 3
        assert rt_cols == 4

        # Intrinsics: a single 3x3 matrix.
        k_shape = self.K.shape
        assert len(k_shape) == 2
        k_rows, k_cols = k_shape
        assert k_rows == 3
        assert k_cols == 3

        # The remaining checks only apply when image paths were supplied.
        if self.imgnames is None:
            return
        assert self.num_frames == len(self.imgnames)
        assert self.num_frames > 0
        # Only the first path is probed on disk.
        im_p = self.imgnames[0]
        assert op.exists(im_p), f"Image path {im_p} does not exist"
class ARCTICViewer:
    """
    Wrapper around an aitviewer ``Viewer`` (interactive) or ``HeadlessRenderer`` (offline).

    ``render_seq`` is the main entry point: it configures the viewer from a
    ``ViewerData`` object, adds the meshes to the scene and then either opens the
    interactive window or exports every frame to disk.
    """

    def __init__(
        self,
        render_types=None,
        interactive=True,
        size=(2024, 2024),
    ):
        """
        Args:
            render_types: Collection of outputs to export in headless mode. The names
                checked by ``view_fn_headless`` are "rgb", "depth", "mask" and "video".
                ``None`` (the default) selects ``["rgb", "depth", "mask"]``.
            interactive: When true a ``Viewer`` window is created, otherwise a
                ``HeadlessRenderer``.
            size: Passed to ``Viewer`` only; it is not forwarded to ``HeadlessRenderer``.
        """
        # Build the default per instance: a list literal in the signature would be one
        # shared, mutable object across every viewer created with the default.
        if render_types is None:
            render_types = ["rgb", "depth", "mask"]
        self.v = Viewer(size=size) if interactive else HeadlessRenderer()
        self.interactive = interactive
        self.render_types = render_types

    def view_interactive(self):
        """Run the wrapped viewer (delegates to ``self.v.run()``)."""
        self.v.run()

    def view_fn_headless(self, num_iter, out_folder):
        """
        Export ``num_iter`` consecutive frames below ``out_folder``.

        Per frame, depending on ``self.render_types``:
        ``images/rgb/NNNN.png``, ``images/depth/NNNN.npy`` and ``images/mask/NNNN.png``.
        If "video" is requested, ``save_video`` is called once up front with
        ``<out_folder>/video.mp4``.

        Args:
            num_iter: Number of frames to export; the scene is advanced after each one.
            out_folder: Root directory for all outputs.
        """
        v = self.v
        v._init_scene()
        logger.info("Rendering to video")
        if "video" in self.render_types:
            vid_p = op.join(out_folder, "video.mp4")
            v.save_video(video_dir=vid_p)
        pbar = tqdm(range(num_iter))
        for fidx in pbar:
            out_rgb = op.join(out_folder, "images", f"rgb/{fidx:04d}.png")
            out_mask = op.join(out_folder, "images", f"mask/{fidx:04d}.png")
            out_depth = op.join(out_folder, "images", f"depth/{fidx:04d}.npy")
            # render RGB, depth, segmentation masks
            if "rgb" in self.render_types:
                # Create the target directory like the depth/mask branches do, so the
                # export does not depend on the viewer creating it.
                os.makedirs(op.dirname(out_rgb), exist_ok=True)
                v.export_frame(out_rgb)
            if "depth" in self.render_types:
                os.makedirs(op.dirname(out_depth), exist_ok=True)
                render_depth(v, out_depth)
            if "mask" in self.render_types:
                os.makedirs(op.dirname(out_mask), exist_ok=True)
                render_mask(v, out_mask)
            v.scene.next_frame()
        logger.info(f"Exported to {out_folder}")

    @abstractmethod
    def load_data(self):
        """
        Hook for subclasses to load their data; the base implementation does nothing.

        Note: this class does not derive from ``abc.ABC``, so ``@abstractmethod`` is
        not enforced at instantiation time.
        """
        pass

    def check_format(self, batch):
        """
        Assert that ``batch`` is ``(meshes_all, data)`` with the expected types.

        ``meshes_all`` must be a non-empty dict whose values are ``Meshes`` and
        ``data`` must be a ``ViewerData``.

        Raises:
            AssertionError: If any of the checks fails.
        """
        meshes_all, data = batch
        assert isinstance(meshes_all, dict)
        assert len(meshes_all) > 0
        for mesh in meshes_all.values():
            assert isinstance(mesh, Meshes)
        assert isinstance(data, ViewerData)

    def render_seq(self, batch, out_folder="./render_out"):
        """
        Show or export one sequence.

        Args:
            batch: Pair ``(meshes_all, data)``; see ``check_format`` for the expected
                types (this method does not call ``check_format`` itself).
            out_folder: Output root used in headless mode; unused when interactive.
        """
        meshes_all, data = batch
        self.setup_viewer(data)
        for mesh in meshes_all.values():
            self.v.scene.add(mesh)
        if self.interactive:
            self.view_interactive()
        else:
            num_iter = data["num_frames"]
            self.view_fn_headless(num_iter, out_folder)

    def setup_viewer(self, data):
        """
        Configure playback and scene decorations on the wrapped viewer.

        If ``data`` has an "imgnames" key, the camera/billboard is installed through
        ``setup_billboard``. Playback is set to 30 fps with autoplay off, and the
        origin marker and floor are disabled.
        """
        v = self.v
        fps = 30
        if "imgnames" in data:
            setup_billboard(data, v)
        # camera.show_path()
        # Autoplay is off. The original code set this flag to True and immediately
        # overwrote it with False; only the effective value is kept.
        v.run_animations = False
        v.playback_fps = fps
        v.scene.fps = fps
        v.scene.origin.enabled = False
        v.scene.floor.enabled = False
        v.auto_set_floor = False
        # Second component of the floor position.
        v.scene.floor.position[1] = -3
        # v.scene.camera.position = np.array((0.0, 0.0, 0))
def dist2vc(dist_ro, dist_lo, dist_o, _cmap, tf_fn=None):
    """
    Convert three distance inputs into colors.

    Each input is first passed through a transfer function and the result is then
    fed to ``_cmap``. All three transfer calls happen before the first ``_cmap`` call.

    Args:
        dist_ro, dist_lo, dist_o: Distance values; presumably right-hand, left-hand and
            object distances -- confirm against callers.
        _cmap: Callable applied to each transformed distance (e.g. a colormap).
        tf_fn: Optional transfer function; ``small_exp_map`` is used when ``None``.

    Returns:
        Tuple ``(vc_ro, vc_lo, vc_o)`` in the same order as the inputs.
    """
    transfer = small_exp_map if tf_fn is None else tf_fn
    mapped = [transfer(dist) for dist in (dist_ro, dist_lo, dist_o)]
    vc_ro, vc_lo, vc_o = [_cmap(value) for value in mapped]
    return vc_ro, vc_lo, vc_o
def small_exp_map(_dist):
    """
    Map distances to ``exp(-20 * d)``.

    A distance of 0 maps to 1 and larger distances decay towards 0. The input is
    copied first, so the caller's array is never modified and array-likes such as
    lists are accepted.
    """
    return np.exp(-20.0 * np.copy(_dist))
def construct_viewer_meshes(data, draw_edges=False, flat_shading=True):
    """
    Build one ``Meshes`` renderable per entry of ``data``.

    Args:
        data: Mapping from mesh key to a spec with the entries "v3d", "f3d", "vc",
            "name" and "color" ("color" must be a key of ``materials``).
        draw_edges: Forwarded to every ``Meshes``.
        flat_shading: Shading mode for meshes whose key does not contain "object".
            Meshes whose key contains "object" are always created with
            ``flat_shading=False``.

    Returns:
        Dict with the same keys as ``data`` mapping to the created ``Meshes``.
    """
    # Axis-angle of pi about the x-axis, converted to a rotation and applied to every mesh.
    rotation_flip = aa2rot_numpy(np.array([1, 0, 0]) * np.pi)
    meshes = {}
    for key, val in data.items():
        # Decide shading per mesh. The previous code overwrote the ``flat_shading``
        # argument when it met an "object" key, which silently forced smooth shading
        # on every mesh processed after it.
        mesh_flat_shading = False if "object" in key else flat_shading
        meshes[key] = Meshes(
            val["v3d"],
            val["f3d"],
            vertex_colors=val["vc"],
            name=val["name"],
            flat_shading=mesh_flat_shading,
            draw_edges=draw_edges,
            material=materials[val["color"]],
            rotation=rotation_flip,
        )
    return meshes
def setup_viewer(
    v, shared_folder_p, video, images_path, data, flag, seq_name, side_angle
):
    """
    Install a fixed-intrinsics OpenCV camera (and optionally an image billboard) on ``v``.

    Images are read from ``<shared_folder_p>/images`` and ordered by the integer at the
    end of each file name (extension removed). The per-frame camera translation comes
    from ``data[f"{flag}.object.cam_t"]``; the sequence length is the shorter of the
    translation array and the image list.

    Args:
        v: Viewer to configure; it is also returned.
        shared_folder_p: Folder that contains the ``images`` sub-directory.
        video: Not read by this function.
        images_path: Ignored -- it is overwritten with ``<shared_folder_p>/images``.
        data: Mapping that provides ``f"{flag}.object.cam_t"``.
        flag: Prefix used to build the ``cam_t`` key.
        seq_name: Not read by this function.
        side_angle: The image billboard is only added when this is ``None``.

    Returns:
        The configured viewer ``v``.

    Raises:
        AssertionError: If the images folder is empty.
        ValueError: If a file name does not end in digits (``int("")``).
    """
    fps = 10
    cols, rows = 224, 224
    focal = 1000.0

    # setup image paths
    regex = re.compile(r"(\d*)$")

    def sort_key(x):
        # Trailing digits of the file name without extension, as an integer.
        name = os.path.splitext(x)[0]
        return int(regex.search(name).group(0))

    # setup billboard
    images_path = op.join(shared_folder_p, "images")
    images_paths = [
        os.path.join(images_path, f)
        for f in sorted(os.listdir(images_path), key=sort_key)
    ]
    assert len(images_paths) > 0

    cam_t = data[f"{flag}.object.cam_t"]
    num_frames = min(cam_t.shape[0], len(images_paths))
    cam_t = cam_t[:num_frames]
    # Keep the image list in step with the truncated camera track.
    images_paths = images_paths[:num_frames]

    # setup camera
    # Principal point at the image centre: cx is half the width (cols), cy half the
    # height (rows). The two were swapped before, which only went unnoticed because
    # cols == rows here.
    K = np.array([[focal, 0, cols / 2.0], [0, focal, rows / 2.0], [0, 0, 1]])
    Rt = np.zeros((num_frames, 3, 4))
    Rt[:, :, 3] = cam_t
    Rt[:, :3, :3] = np.eye(3)
    # Negate the second and third rotation rows (rotation becomes diag(1, -1, -1));
    # the translation column is left untouched.
    Rt[:, 1:3, :3] *= -1.0

    camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)
    if side_angle is None:
        billboard = Billboard.from_camera_and_distance(
            camera, 10.0, cols, rows, images_paths
        )
        v.scene.add(billboard)
    v.scene.add(camera)

    v.run_animations = True  # autoplay
    v.playback_fps = fps
    v.scene.fps = fps
    v.scene.origin.enabled = False
    v.scene.floor.enabled = False
    v.auto_set_floor = False
    v.scene.floor.position[1] = -3
    v.set_temp_camera(camera)
    # v.scene.camera.position = np.array((0.0, 0.0, 0))
    return v
def render_depth(v, depth_p):
    """
    Save the viewer's current depth output to ``depth_p`` with ``np.save``.

    The values returned by ``v.get_depth()`` are converted to ``float16`` before
    they are written.
    """
    depth_map = np.array(v.get_depth())
    np.save(depth_p, depth_map.astype(np.float16))
def render_mask(v, mask_p):
    """
    Save a segmentation mask of the current frame to ``mask_p``.

    Every scene node whose name is a key of ``SEGM_IDS`` is assigned the color
    ``[id, id, id]``. That color map is handed to ``v.get_mask`` and the result is
    written as a ``uint8`` image.
    """
    # name -> uid; if several nodes share a name, the last one collected wins.
    uid_by_name = {}
    for node in v.scene.collect_nodes():
        uid_by_name[node.name] = node.uid

    # uid -> three-channel color, restricted to the names that have a segmentation id.
    uid_to_color = {}
    for name, uid in uid_by_name.items():
        if name not in SEGM_IDS.keys():
            continue
        uid_to_color[uid] = [SEGM_IDS[name]] * 3

    mask_pixels = np.array(v.get_mask(color_map=uid_to_color)).astype(np.uint8)
    Image.fromarray(mask_pixels).save(mask_p)
def setup_billboard(data, v):
    """
    Add an ``OpenCVCamera`` built from ``data`` to the viewer and look through it.

    Reads ``imgnames``, ``K``, ``Rt``, ``rows`` and ``cols`` from ``data``. When
    ``imgnames`` is not ``None`` a billboard showing those images is added as well,
    created from the camera at a distance of 10.0.
    """
    images_paths, K, Rt = data.imgnames, data.K, data.Rt
    rows, cols = data.rows, data.cols

    camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)

    # The billboard is optional; the camera is always added.
    if images_paths is not None:
        v.scene.add(
            Billboard.from_camera_and_distance(camera, 10.0, cols, rows, images_paths)
        )
    v.scene.add(camera)

    v.scene.camera.load_cam()
    v.set_temp_camera(camera)