# Source: PSHuman / lib/common/render_utils.py
# Uploaded by fffiloni — "Migrated from GitHub", commit 2252f3d (verified)
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import torch
from torch import nn
import trimesh
import math
from typing import NewType
from pytorch3d.structures import Meshes
from pytorch3d.renderer.mesh import rasterize_meshes
# Alias used purely for type annotations below; at runtime calling
# Tensor(x) simply returns x unchanged.
Tensor = NewType("Tensor", torch.Tensor)
def solid_angles(points: Tensor,
                 triangles: Tensor,
                 thresh: float = 1e-8) -> Tensor:
    ''' Compute the signed solid angle subtended by each triangle at each point

    Follows the method described in:
    The Solid Angle of a Plane Triangle
    A. VAN OOSTEROM AND J. STRACKEE
    IEEE TRANSACTIONS ON BIOMEDICAL ENGINEERING,
    VOL. BME-30, NO. 2, FEBRUARY 1983

    With a, b, c the triangle vertices expressed relative to the query point:

        tan(Omega / 2) = a . (b x c) /
            (|a||b||c| + (a . b)|c| + (a . c)|b| + (b . c)|a|)

    atan2 is used so that Omega / 2 is recovered in the correct quadrant.

    Parameters
    -----------
    points: BxQx3
        Tensor of input query points
    triangles: BxFx3x3
        Target triangles
    thresh: float
        Currently unused; kept only for backward compatibility of the
        signature.
    Returns
    -------
    solid_angles: BxQxF
        A tensor containing the signed solid angle between every query point
        and every input triangle. The sign is positive when a . (b x c) > 0,
        i.e. it depends on the triangle's vertex order.
    '''
    # Center the triangles on the query points. Size should be BxQxFx3x3
    centered_tris = triangles[:, None] - points[:, :, None, None]
    # Distance from the query point to each triangle vertex: BxQxFx3
    norms = torch.norm(centered_tris, dim=-1)

    # Triple product a . (b x c): BxQxF
    cross_prod = torch.cross(centered_tris[:, :, :, 1],
                             centered_tris[:, :, :, 2],
                             dim=-1)
    numerator = (centered_tris[:, :, :, 0] * cross_prod).sum(dim=-1)
    del cross_prod

    # Pairwise dot products between the centered vertices: BxQxF each
    dot01 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 1]).sum(dim=-1)
    dot12 = (centered_tris[:, :, :, 1] * centered_tris[:, :, :, 2]).sum(dim=-1)
    dot02 = (centered_tris[:, :, :, 0] * centered_tris[:, :, :, 2]).sum(dim=-1)
    # The BxQxFx3x3 tensor is the dominant memory cost; release it early.
    del centered_tris

    denominator = (norms.prod(dim=-1) + dot01 * norms[:, :, :, 2] +
                   dot02 * norms[:, :, :, 1] + dot12 * norms[:, :, :, 0])
    del dot01, dot12, dot02, norms

    # Half solid angle per (point, triangle) pair: BxQxF
    solid_angle = torch.atan2(numerator, denominator)
    del numerator, denominator
    # Only worth returning cached blocks to the driver when the large
    # intermediates above actually lived on the GPU; on CPU inputs this
    # would just flush an unrelated CUDA cache.
    if solid_angle.is_cuda:
        torch.cuda.empty_cache()
    return 2 * solid_angle
def winding_numbers(points: Tensor,
                    triangles: Tensor,
                    thresh: float = 1e-8) -> Tensor:
    ''' Generalized winding number of each query point w.r.t. a triangle soup

    The generalized winding number is the sum of the signed solid angles that
    all triangles subtend at the query point, normalized by the full-sphere
    solid angle 4*pi.

    References
    ----------
    Robust Inside-Outside Segmentation using Generalized Winding Numbers,
    Alec Jacobson, Ladislav Kavan, Olga Sorkine-Hornung.

    Fast Winding Numbers for Soups and Clouds, SIGGRAPH 2018,
    Gavin Barill, Neil G. Dickson, Ryan Schmidt, David I.W. Levin,
    Alec Jacobson.

    Parameters
    -----------
    points: BxQx3
        Tensor of input query points
    triangles: BxFx3x3
        Target triangles
    thresh: float
        Forwarded unchanged to ``solid_angles``.
    Returns
    -------
    winding_numbers: BxQ
        The generalized winding number of every query point
    '''
    # Per (point, triangle) signed solid angles: BxQxF
    per_triangle = solid_angles(points, triangles, thresh=thresh)
    # Accumulate over all triangles: BxQ
    total_angle = per_triangle.sum(dim=-1)
    return 1 / (4 * math.pi) * total_angle
def batch_contains(verts, faces, points):
    """Signed inside/outside test of query points against a batch of meshes.

    Each mesh ``(verts[b], faces[b])`` is turned into a ``trimesh.Trimesh``
    and queried with ``Trimesh.contains``.

    Args:
        verts: per-batch vertex arrays, indexed as ``verts[b]``; the batch
            size is read from ``verts.shape[0]``.
        faces: per-batch face index arrays, indexed as ``faces[b]``.
        points: query points; the number of points per batch element is
            read from ``points.shape[1]`` (presumably ``(B, N, 3)``).

    Returns:
        A CPU tensor of shape ``(B, N)`` holding ``+1.0`` for points reported
        inside the mesh and ``-1.0`` otherwise.
    """
    batch_size, num_points = verts.shape[0], points.shape[1]
    # trimesh works on host memory: detach from autograd and move to CPU.
    verts_cpu, faces_cpu, points_cpu = (
        t.detach().cpu() for t in (verts, faces, points))
    inside = torch.zeros(batch_size, num_points)
    for b in range(batch_size):
        mesh = trimesh.Trimesh(verts_cpu[b], faces_cpu[b])
        inside[b] = torch.as_tensor(mesh.contains(points_cpu[b]))
    # Map {0, 1} -> {-1, +1}.
    return 2.0 * (inside - 0.5)
def dict2obj(d):
    """Recursively wrap a dict so that its keys are reachable as attributes.

    Non-dict values (including lists, even lists of dicts) are returned
    unchanged. Every dict becomes an instance of a freshly defined local
    class ``C`` whose instance ``__dict__`` holds the converted values.
    """
    if isinstance(d, dict):
        class C(object):
            pass
        wrapped = C()
        wrapped.__dict__.update({key: dict2obj(d[key]) for key in d})
        return wrapped
    return d
def face_vertices(vertices, faces):
    """
    Gather, for every face, the attributes of its three vertices.

    :param vertices: [batch size, number of vertices, C] (C is usually 3)
    :param faces: [batch size (or 1), number of faces, 3] integer indices into
        the vertices of the SAME batch element. A faces batch size of 1 is
        broadcast, i.e. one topology shared by every mesh in the batch.
        Negative indices count from the end of that element's vertex list.
    :return: [batch size, number of faces, 3, C]
    :raises IndexError: (CPU tensors) if a face index is out of range for the
        vertex count
    """
    bs = vertices.shape[0]
    # Pair each face index with its own batch row. Indexing per batch element
    # (instead of flattening vertices and adding a b * nv offset) means an
    # out-of-range face index is reported rather than silently reading the
    # vertices of a neighbouring mesh in the batch.
    batch_idx = torch.arange(bs, device=vertices.device)[:, None, None]
    return vertices[batch_idx, faces.long()]
class Pytorch3dRasterizer(nn.Module):
    """ Borrowed from https://github.com/facebookresearch/pytorch3d

    Thin wrapper around ``pytorch3d.renderer.mesh.rasterize_meshes`` that
    rasterizes a batch of meshes and barycentrically interpolates per-face
    vertex attributes.

    Notice:
        x,y,z are in image space, normalized
        can only render squared image now
    """

    def __init__(self, image_size=224):
        """
        use fixed raster_settings for rendering faces

        :param image_size: side length (pixels) of the square output image
        """
        super().__init__()
        raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': True,
            # NOTE(review): this flag is stored but NOT forwarded to
            # rasterize_meshes in forward(), so back faces are currently not
            # culled. Forwarding it would change rendered output — confirm
            # the intended behavior before wiring it through.
            'cull_backfaces': True,
        }
        raster_settings = dict2obj(raster_settings)
        self.raster_settings = raster_settings

    def forward(self, vertices, faces, attributes=None):
        """
        Rasterize the meshes and interpolate ``attributes`` per pixel.

        :param vertices: vertex positions; the first two coordinates (x, y)
            are negated before rasterization and the result is cast to float.
        :param faces: vertex indices of each face (cast to long).
        :param attributes: per-face-vertex attributes, reshaped to
            (N * F, 3, D); presumably (N, F, 3, D), e.g. the output of
            face_vertices — TODO confirm with callers. If None, only the
            visibility channel is returned.
        :return: (N, D + 1, H, W): the interpolated attributes of the closest
            face per pixel (0 where no face is hit), followed by a visibility
            mask channel (1.0 covered, 0.0 empty). When ``attributes`` is None
            the result is just the (N, 1, H, W) visibility channel.
        """
        # Flip x and y on a copy so the caller's tensor is left untouched.
        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        meshes_screen = Meshes(verts=fixed_vertices.float(),
                               faces=faces.long())
        raster_settings = self.raster_settings
        pix_to_face, _zbuf, bary_coords, _dists = rasterize_meshes(
            meshes_screen,
            image_size=raster_settings.image_size,
            blur_radius=raster_settings.blur_radius,
            faces_per_pixel=raster_settings.faces_per_pixel,
            bin_size=raster_settings.bin_size,
            max_faces_per_bin=raster_settings.max_faces_per_bin,
            perspective_correct=raster_settings.perspective_correct,
        )
        # pix_to_face is -1 where no face covers the pixel.
        vismask = (pix_to_face > -1).float()
        # Visibility of the closest face, channels-first: (N, 1, H, W)
        vis_channel = vismask[:, :, :, 0][:, None, :, :]
        if attributes is None:
            return vis_channel

        D = attributes.shape[-1]
        # pix_to_face indexes the packed (batch-concatenated) face list, so
        # merge the batch and face dims. reshape avoids a copy for contiguous
        # inputs and still works for non-contiguous ones; gather below never
        # writes into attributes, so no defensive clone is needed.
        attributes = attributes.reshape(
            attributes.shape[0] * attributes.shape[1], 3, D)
        N, H, W, K, _ = bary_coords.shape
        mask = pix_to_face == -1
        # Point empty pixels at face 0 so gather stays in range; their values
        # are zeroed out below.
        safe_pix_to_face = pix_to_face.clamp(min=0)
        idx = safe_pix_to_face.view(N * H * W * K, 1,
                                    1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        # Barycentric interpolation over the three face vertices: (N,H,W,K,D)
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0  # Replace masked values in output.
        # Keep the closest face (K index 0), channels-first: (N, D, H, W)
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        return torch.cat([pixel_vals, vis_channel], dim=1)