Spaces:
Build error
Build error
import io | |
import os | |
import torch | |
from skimage.io import imread | |
import numpy as np | |
import cv2 | |
from tqdm import tqdm_notebook as tqdm | |
import base64 | |
from IPython.display import HTML | |
# Util function for loading meshes | |
from pytorch3d.io import load_objs_as_meshes | |
from IPython.display import HTML | |
from base64 import b64encode | |
# Data structures and functions for rendering | |
from pytorch3d.structures import Meshes | |
from pytorch3d.renderer import ( | |
look_at_view_transform, | |
OpenGLOrthographicCameras, | |
PointLights, | |
DirectionalLights, | |
Materials, | |
RasterizationSettings, | |
MeshRenderer, | |
MeshRasterizer, | |
SoftPhongShader, | |
HardPhongShader, | |
TexturesVertex | |
) | |
def set_renderer():
    """Build the mesh renderer used to draw frames.

    Everything is created on ``cuda:0``, which is also made the current CUDA
    device. The renderer pairs a rasterizer (512 image size, no blur, one
    face per pixel) with a hard Phong shader lit by a single point light
    located at (2.0, 2.0, 2.0).

    Returns:
        The configured ``MeshRenderer``.
    """
    gpu = torch.device("cuda:0")
    torch.cuda.set_device(gpu)

    # Orthographic (not perspective) camera; the view transform is built
    # from the positional arguments (2.0, 0, 180).
    rotation, translation = look_at_view_transform(2.0, 0, 180)
    camera = OpenGLOrthographicCameras(device=gpu, R=rotation, T=translation)

    rasterizer = MeshRasterizer(
        cameras=camera,
        raster_settings=RasterizationSettings(
            image_size=512,
            blur_radius=0.0,
            faces_per_pixel=1,
            bin_size=None,
            max_faces_per_bin=None,
        ),
    )
    shader = HardPhongShader(
        device=gpu,
        cameras=camera,
        lights=PointLights(device=gpu, location=((2.0, 2.0, 2.0),)),
    )
    return MeshRenderer(rasterizer=rasterizer, shader=shader)
def get_verts_rgb_colors(obj_path):
    """Read the per-vertex RGB colors stored on the ``v`` lines of an OBJ file.

    A colored vertex record has seven whitespace-separated fields,
    ``v x y z r g b``; the last three are collected in file order. Lines of
    any other kind or length (comments, faces, uncolored vertices, ...) are
    ignored.

    Args:
        obj_path: Path of the OBJ file to read.

    Returns:
        A ``float32`` NumPy array of shape ``(1, N, 3)``, where ``N`` is the
        number of colored vertex lines found (``N`` may be 0).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a color field of a vertex line is not a number.
    """
    rgb_colors = []
    # The context manager closes the file even if parsing fails part-way.
    with open(obj_path) as obj_file:
        for line in obj_file:
            # Split on any whitespace so tabs, repeated spaces and the
            # trailing newline do not change the field count.
            fields = line.split()
            if len(fields) == 7 and fields[0] == "v":
                rgb_colors.append(fields[-3:])
    # reshape (rather than indexing with [None, :, :]) keeps the result
    # well-formed, shape (1, 0, 3), when no colored vertex was found.
    return np.array(rgb_colors, dtype="float32").reshape(1, -1, 3)
def _render_bgr_frame(renderer, mesh, R, T):
    """Render *mesh* once and return the first image as BGR values in [0, 255]."""
    rendered = renderer(mesh, R=R, T=T)
    # Keep the first three channels of the first image and clamp to [0, 1].
    rgb = np.clip(rendered[0, ..., :3].cpu().numpy(), 0.0, 1.0)
    # Reverse the channel axis (RGB -> BGR, the order OpenCV writes) and scale.
    return rgb[:, :, ::-1] * 255


def generate_video_from_obj(obj_path, video_path, renderer):
    """Render a rotating view of an OBJ mesh to a video file.

    Each of the 90 frames shows the mesh twice, side by side: on the left
    with the per-vertex colors read from the OBJ file, on the right with a
    uniform 0.75 gray. The third argument of ``look_at_view_transform``
    advances by 4 per frame, and the video is written at 20 fps.

    Args:
        obj_path: Path of the OBJ file; its ``v`` lines are read for colors
            by ``get_verts_rgb_colors``.
        video_path: Path of the video file to create.
        renderer: Callable invoked as ``renderer(mesh, R=R, T=T)``.
            NOTE(review): the writer is opened with a fixed 1024x512 frame
            size, so the renderer is assumed to produce 512x512 images (as
            the one built by ``set_renderer`` is configured to) -- confirm
            for other renderers.
    """
    # Setup
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Vertex colors, and a flat gray alternative of the same shape.
    verts_rgb_colors = torch.from_numpy(get_verts_rgb_colors(obj_path)).to(device)
    textures = TexturesVertex(verts_features=verts_rgb_colors)
    wo_textures = TexturesVertex(
        verts_features=torch.ones_like(verts_rgb_colors) * 0.75
    )

    # Load the geometry and rebuild it with each texture. The public
    # accessors are used instead of the private _verts_list/_faces_list
    # attributes, which are an implementation detail of Meshes.
    mesh = load_objs_as_meshes([obj_path], device=device)
    verts = mesh.verts_list()
    faces = mesh.faces_list()
    mesh_w_tex = Meshes(verts, faces, textures)
    mesh_wo_tex = Meshes(verts, faces, wo_textures)

    # create VideoWriter
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    out = cv2.VideoWriter(video_path, fourcc, 20.0, (1024, 512))
    try:
        for i in tqdm(range(90)):
            R, T = look_at_view_transform(1.8, 0, i * 4, device=device)
            images_w_tex = _render_bgr_frame(renderer, mesh_w_tex, R, T)
            images_wo_tex = _render_bgr_frame(renderer, mesh_wo_tex, R, T)
            # Concatenate along axis 1: textured left, gray right.
            image = np.concatenate([images_w_tex, images_wo_tex], axis=1)
            out.write(image.astype('uint8'))
    finally:
        # Always release the writer so the file is closed even if a render
        # call fails part-way through.
        out.release()
def video(path):
    """Return an ``HTML`` object that embeds the file at *path* as a video.

    The whole file is read into memory and inlined as a base64
    ``data:video/mp4`` URL inside a ``<video>`` tag (width 500, with
    controls, looping).

    Args:
        path: Path of the video file to embed.

    Returns:
        An ``IPython.display.HTML`` object wrapping the ``<video>`` markup.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Use a context manager so the file handle is closed promptly.
    with open(path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML('<video width=500 controls loop> <source src="%s" type="video/mp4"></video>' % data_url)