Spaces:
Runtime error
Runtime error
# part of the code from | |
# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/p3d_renderer.py | |
import torch | |
import torch.nn.functional as F | |
from scipy.io import loadmat | |
import numpy as np | |
# import config | |
import pytorch3d | |
from pytorch3d.structures import Meshes | |
from pytorch3d.renderer import ( | |
PerspectiveCameras, look_at_view_transform, look_at_rotation, | |
RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams, | |
PointLights, HardPhongShader, SoftSilhouetteShader, Materials, Textures, | |
DirectionalLights | |
) | |
from pytorch3d.renderer import TexturesVertex, SoftPhongShader | |
from pytorch3d.io import load_objs_as_meshes | |
# Default per-vertex RGB colors (0-255 range) used to texture rendered meshes.
MESH_COLOR_0 = [0, 172, 223]
MESH_COLOR_1 = [172, 223, 0]
''' | |
Explanation of the shift between projection results from opendr and pytorch3d: | |
(0, 0, ?) will be projected to 127.5 (pytorch3d) instead of 128 (opendr) | |
imagine you have an image of size 4: | |
middle of the first pixel is 0 | |
middle of the last pixel is 3 | |
=> middle of the image would be 1.5 and not 2!
so in order to go from pytorch3d predictions to opendr we would calculate: p_odr = p_p3d * (128/127.5) | |
To reproject points (p3d) by hand according to this pytorch3d renderer we would do the following steps: | |
1.) build camera matrix | |
K = np.array([[flength, 0, c_x],
              [0, flength, c_y],
              [0, 0, 1]], np.float64)
2.) we don't need to add extrinsics, as the mesh comes with translation (which is | |
added within smal_pytorch). all 3d points are already in the camera coordinate system. | |
-> projection reduces to p2d_proj = K*p3d | |
3.) convert to pytorch3d conventions (0 in the middle of the first pixel) | |
p2d_proj_pytorch3d = p2d_proj / image_size * (image_size-1.) | |
renderer.py - project_points_p3d: shows an example of what is described above, but | |
same focal length for the whole batch | |
''' | |
class SilhRenderer(torch.nn.Module):
    """Soft silhouette renderer with an optional hard-shaded color preview.

    Wraps PyTorch3D rasterizers/shaders around a fixed set of extrinsics
    (identity translation, axis-flipping rotation) and a per-sample focal
    length.  Besides rendering silhouettes it can project 3D points to the
    image plane and compute per-vertex visibility.

    The camera construction is only implemented for pytorch3d versions
    '0.2.5' and '0.6.1'; any other version raises ``ValueError``.

    Note: results are around 0.5 pixels off compared to chumpy/opendr (see the
    explanation in the module-level docstring).
    """

    def __init__(self, image_size, adapt_R_wldo=False):
        """Set up constant buffers and rasterization settings.

        Args:
            image_size: one integer number; the rendered images are square.
            adapt_R_wldo: if True only the x axis of the rotation is flipped,
                otherwise (default) both the x and the y axis are flipped.
        """
        super(SilhRenderer, self).__init__()
        # see: https://pytorch3d.org/files/fit_textured_mesh.py, line 315
        # -----
        # mesh colors are buffers so that they follow the module across devices
        self.register_buffer('mesh_color_0', torch.FloatTensor(MESH_COLOR_0))
        self.register_buffer('mesh_color_1', torch.FloatTensor(MESH_COLOR_1))
        # prepare extrinsics, which in our case don't change
        R = torch.eye(3, dtype=torch.float32)[None, :, :]
        T = torch.zeros((1, 3), dtype=torch.float32)
        R[0, 0, 0] = -1
        if not adapt_R_wldo:
            # NOTE(review): the original comments state this branch was the
            # one used for the author's own experiments.
            R[0, 1, 1] = -1
        self.register_buffer('R', R)
        self.register_buffer('T', T)
        # prepare that part of the intrinsics which does not change either
        self.img_size_scalar = image_size
        self.register_buffer(
            'image_size',
            torch.tensor([[image_size, image_size]], dtype=torch.float32))
        self.register_buffer(
            'principal_point',
            torch.tensor([[image_size / 2., image_size / 2.]],
                         dtype=torch.float32))
        # Rasterization settings for differentiable rendering, where the
        # blur_radius initialization is based on Liu et al, 'Soft Rasterizer:
        # A Differentiable Renderer for Image-based 3D Reasoning', ICCV 2019
        self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        self.raster_settings_soft = RasterizationSettings(
            image_size=image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * self.blend_params.sigma,
            faces_per_pixel=100)
        # same construction with a larger sigma, for body part segmentation
        self.blend_params_parts = BlendParams(sigma=2 * 1e-4, gamma=1e-4)
        self.raster_settings_soft_parts = RasterizationSettings(
            image_size=image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * self.blend_params_parts.sigma,
            faces_per_pixel=60)
        # settings for visualization renderer (hard rasterization)
        self.raster_settings_vis = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=1)

    @staticmethod
    def _unsupported_version_error():
        """Build the error raised for pytorch3d versions without a code path."""
        return ValueError(
            'this part depends on the version of pytorch3d, code was '
            'developed with 0.2.5 (found version: {})'.format(
                pytorch3d.__version__))

    def _get_cam(self, focal_lengths):
        """Create a batch of perspective cameras.

        Args:
            focal_lengths: tensor of size [bs, 1]; the same value is used for
                the x and the y direction.

        Returns:
            PerspectiveCameras with ``bs`` entries on the device of
            ``focal_lengths``.

        Raises:
            ValueError: if the installed pytorch3d version is neither '0.2.5'
                nor '0.6.1'.
        """
        bs = focal_lengths.shape[0]
        cam_kwargs = dict(
            device=focal_lengths.device,
            focal_length=focal_lengths.repeat((1, 2)),
            principal_point=self.principal_point.repeat((bs, 1)),
            R=self.R.repeat((bs, 1, 1)),
            T=self.T.repeat((bs, 1)),
            image_size=self.image_size.repeat((bs, 1)))
        version = pytorch3d.__version__
        if version == '0.2.5':
            return PerspectiveCameras(**cam_kwargs)
        if version == '0.6.1':
            # intrinsics above are given in pixels, not in NDC
            return PerspectiveCameras(in_ndc=False, **cam_kwargs)
        raise self._unsupported_version_error()

    @staticmethod
    def _colored_mesh(vertices, faces, mesh_color):
        """Build a Meshes object whose vertices all carry ``mesh_color``."""
        tex = torch.ones_like(vertices) * mesh_color    # (bs, V, 3)
        textures = Textures(verts_rgb=tex)
        return Meshes(verts=vertices, faces=faces, textures=textures)

    def _get_visualization_from_mesh(self, mesh, cameras, lights=None):
        """Render a hard-shaded color image of ``mesh`` without gradients.

        Args:
            mesh: textured pytorch3d Meshes object.
            cameras: pytorch3d cameras used for rasterization and shading.
            lights: optional pytorch3d lights; a point light at (0, 0, 3) is
                used when None.

        Returns:
            The first three channels of the rendering, channel dimension
            moved to position 1.
        """
        # color renderer for visualization
        with torch.no_grad():
            device = mesh.device
            if lights is None:
                lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
            vis_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(
                    cameras=cameras,
                    raster_settings=self.raster_settings_vis),
                shader=HardPhongShader(
                    device=device,
                    cameras=cameras,
                    lights=lights))
            # render image: (N, H, W, C) -> (N, 3, H, W)
            visualization = vis_renderer(mesh).permute(0, 3, 1, 2)[:, :3, :, :]
        return visualization

    @staticmethod
    def _visible_vertex_mask(pix_to_face, packed_faces, num_verts):
        """Mark every vertex that belongs to at least one rasterized face.

        Args:
            pix_to_face: integer tensor of packed face indices per pixel;
                negative entries denote pixels not covered by any face.
            packed_faces: (F, 3) integer tensor of packed vertex indices.
            num_verts: total number of packed vertices.

        Returns:
            Float tensor of size (num_verts,) with 1.0 for visible vertices
            and 0.0 otherwise, on the device of ``packed_faces``.
        """
        visible_faces = pix_to_face.unique()
        # Drop the background id: indexing with -1 would silently select the
        # last face and mark its vertices as visible.
        visible_faces = visible_faces[visible_faces >= 0]
        # indices of unique visible verts via the vertex indices in the faces
        visible_verts = torch.unique(packed_faces[visible_faces])
        mask = torch.zeros(num_verts, device=packed_faces.device)
        mask[visible_verts] = 1.0
        return mask

    def calculate_vertex_visibility(self, vertices, faces, focal_lengths, soft=False):
        """Use the rasterizer to determine which vertices are visible.

        see: https://github.com/facebookresearch/pytorch3d/issues/126

        Args:
            vertices: tensor of size [bs, V, 3].
            faces: tensor of size [bs, F, 3], int.
            focal_lengths: tensor of size [bs, 1].
            soft: if True the soft rasterization settings are used, otherwise
                the hard visualization settings.

        Returns:
            Tuple ``(pix_to_face, visibility)`` where ``pix_to_face`` is the
            rasterizer output and ``visibility`` has size [bs, V] with 1.0
            for visible vertices.
        """
        mesh = self._colored_mesh(vertices, faces, self.mesh_color_0)
        cameras = self._get_cam(focal_lengths)
        raster_settings = (self.raster_settings_soft if soft
                           else self.raster_settings_vis)
        rasterizer = MeshRasterizer(cameras=cameras,
                                    raster_settings=raster_settings)
        # Get the output from rasterization
        fragments = rasterizer(mesh)
        pix_to_face = fragments.pix_to_face
        # (F, 3) where F is the total number of faces across all the meshes in the batch
        packed_faces = mesh.faces_packed()
        # (V, 3) where V is the total number of verts across all the meshes in the batch
        packed_verts = mesh.verts_packed()
        vertex_visibility_map = self._visible_vertex_mask(
            pix_to_face, packed_faces, packed_verts.shape[0])
        # since all meshes have the same amount of vertices, we can reshape the result
        bs = vertices.shape[0]
        return pix_to_face, vertex_visibility_map.reshape((bs, -1))

    def get_torch_meshes(self, vertices, faces, color=0):
        """Create a uniformly colored pytorch3d mesh.

        Args:
            vertices: tensor of size [bs, V, 3].
            faces: tensor of size [bs, F, 3], int.
            color: 0 selects the first mesh color, anything else the second.

        Returns:
            pytorch3d Meshes object.
        """
        mesh_color = self.mesh_color_0 if color == 0 else self.mesh_color_1
        return self._colored_mesh(vertices, faces, mesh_color)

    def get_visualization_nograd(self, vertices, faces, focal_lengths, color=0):
        """Render a color visualization of the mesh (no gradients).

        Args:
            vertices: torch.Size([bs, 3889, 3])
            faces: torch.Size([bs, 7774, 3]), int
            focal_lengths: torch.Size([bs, 1])
            color: 0 (blue) and 1 select the registered mesh colors, 3 a
                grey tone; every other value (including 2) gives white.

        Returns:
            Rendered images with the channel dimension at position 1.
        """
        device = vertices.device
        # create cameras
        cameras = self._get_cam(focal_lengths)
        # pick the vertex color
        if color == 0:
            mesh_color = self.mesh_color_0      # blue
        elif color == 1:
            mesh_color = self.mesh_color_1
        else:
            rgb = [166, 173, 164] if color == 3 else [240, 250, 240]
            mesh_color = torch.FloatTensor(rgb).to(device)
        mesh = self._colored_mesh(vertices, faces, mesh_color)
        # render mesh (no gradients)
        lights = DirectionalLights(device=device, direction=[[0.0, -5.0, -10.0]])
        return self._get_visualization_from_mesh(mesh, cameras, lights=lights)

    def project_points(self, points, focal_lengths=None, cameras=None):
        """Project 3D points to screen coordinates.

        Args:
            points: torch.Size([bs, n_points, 3])
            focal_lengths: torch.Size([bs, 1]); only used if ``cameras`` is
                None.
            cameras: pytorch3d camera, for example PerspectiveCameras().

        Returns:
            Tensor of size [bs, n_points, 2] with the projected points.

        Raises:
            ValueError: if neither ``focal_lengths`` nor ``cameras`` is given,
                or if the installed pytorch3d version is not supported.
        """
        if cameras is None:
            if focal_lengths is None:
                raise ValueError(
                    'project_points needs either focal_lengths or cameras')
            cameras = self._get_cam(focal_lengths)
        version = pytorch3d.__version__
        if version == '0.2.5':
            # used in the original virtual environment (for cvpr BARC submission)
            screen_size = self.image_size.repeat((points.shape[0], 1))
            proj_points_orig = cameras.transform_points_screen(
                points, screen_size)[:, :, [1, 0]]
        elif version == '0.6.1':
            proj_points_orig = cameras.transform_points_screen(
                points)[:, :, [1, 0]]
        else:
            raise self._unsupported_version_error()
        # flip, otherwise the 1st and 2nd row are exchanged compared to the
        # ground truth (net effect: the first two screen coordinates, in
        # their original order)
        return torch.flip(proj_points_orig, [2])

    def forward(self, vertices, points, faces, focal_lengths, color=None):
        """Render silhouettes and optionally project points / render colors.

        Args:
            vertices: torch.Size([bs, 3889, 3])
            points: torch.Size([bs, n_points, 3]) (or None)
            faces: torch.Size([bs, 7774, 3]), int
            focal_lengths: torch.Size([bs, 1])
            color: if None we don't render a visualization, else it should
                either be 0 or 1.

        Returns:
            ``(silh_images, proj_points)`` if ``color`` is None, otherwise
            ``(silh_images, proj_points, visualization)``.  ``proj_points``
            is None when ``points`` is None.
        """
        # create cameras
        cameras = self._get_cam(focal_lengths)
        # create pytorch mesh (color None falls back to the first color)
        mesh = self.get_torch_meshes(
            vertices, faces, color=0 if color is None else color)
        # silhouette renderer
        renderer_silh = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=self.raster_settings_soft),
            shader=SoftSilhouetteShader(blend_params=self.blend_params))
        # the silhouette is stored in the last channel of the rendering
        silh_images = renderer_silh(mesh)[..., -1].unsqueeze(1)
        # project points
        proj_points = None
        if points is not None:
            proj_points = self.project_points(points=points, cameras=cameras)
        if color is None:
            return silh_images, proj_points
        # color renderer for visualization (no gradients)
        visualization = self._get_visualization_from_mesh(mesh, cameras)
        return silh_images, proj_points, visualization