# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['MeshLoss', 'GANLoss']


def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.

    Based on Zhou et al., "On the Continuity of Rotation
    Representations in Neural Networks", CVPR 2019.

    Input:
        (B, 6) Batch of 6D rotation representations
    Output:
        (B, 3, 3) Batch of corresponding rotation matrices
    """
    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt orthogonalization of the two 3D vectors; the third
    # basis vector is their cross product.
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    b3 = torch.cross(b1, b2, dim=1)
    return torch.stack((b1, b2, b3), dim=-1)
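

# Illustrative check (not part of the original module): any (almost
# surely non-degenerate) random 6D input maps to a valid rotation
# matrix, i.e. R @ R^T == I.
def _demo_rot6d_to_rotmat():
    x = torch.randn(4, 6)
    rotmat = rot6d_to_rotmat(x)  # (4, 3, 3)
    eye = torch.eye(3).expand(4, 3, 3)
    assert torch.allclose(rotmat @ rotmat.transpose(1, 2), eye, atol=1e-5)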


def batch_rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.

    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the axis-angle input
        -- size = [B, 3, 3]
    """
    # The small epsilon avoids a zero norm (and division by zero) for
    # the zero rotation.
    l2norm = torch.norm(theta + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(l2norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    # Half-angle formula yields the unit quaternion (w, x, y, z).
    quat = torch.cat([v_cos, v_sin * normalized], dim=1)
    return quat_to_rotmat(quat)
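

# Illustrative check (not part of the original module): a rotation of
# pi/2 about the z-axis maps the x-axis onto the y-axis.
def _demo_batch_rodrigues():
    import math
    theta = torch.tensor([[0., 0., math.pi / 2]])
    rotmat = batch_rodrigues(theta)  # (1, 3, 3)
    rotated = rotmat[0] @ torch.tensor([1., 0., 0.])
    assert torch.allclose(rotated, torch.tensor([0., 1., 0.]), atol=1e-6)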


def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrix.

    Args:
        quat: size = [B, 4] 4 <===> (w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion
        -- size = [B, 3, 3]
    """
    norm_quat = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\
        norm_quat[:, 2], norm_quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack([
        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
        w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
        w2 - x2 - y2 + z2
    ],
                         dim=1).view(B, 3, 3)
    return rotMat
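

# Illustrative check (not part of the original module): the identity
# quaternion (w, x, y, z) = (1, 0, 0, 0) yields the identity matrix.
def _demo_quat_to_rotmat():
    quat = torch.tensor([[1., 0., 0., 0.]])
    assert torch.allclose(quat_to_rotmat(quat)[0], torch.eye(3))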


def perspective_projection(points, rotation, translation, focal_length,
                           camera_center):
    """This function computes the perspective projection of a set of 3D
    points.

    Note:
        - batch size: B
        - point number: N

    Args:
        points (Tensor([B, N, 3])): A set of 3D points
        rotation (Tensor([B, 3, 3])): Camera rotation matrix
        translation (Tensor([B, 3])): Camera translation
        focal_length (Tensor([B,])): Focal length
        camera_center (Tensor([B, 2])): Camera center

    Returns:
        projected_points (Tensor([B, N, 2])): Projected 2D
            points in image space.
    """
    batch_size = points.shape[0]
    K = torch.zeros([batch_size, 3, 3], device=points.device)
    K[:, 0, 0] = focal_length
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
    projected_points = projected_points[:, :, :-1]
    return projected_points
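

# Illustrative check (not part of the original module): a point on the
# optical axis must project to the principal point (the camera center).
def _demo_perspective_projection():
    points = torch.tensor([[[0., 0., 5.]]])  # one point, straight ahead
    rotation = torch.eye(3).unsqueeze(0)
    translation = torch.zeros(1, 3)
    focal_length = torch.tensor([1000.])
    camera_center = torch.full((1, 2), 112.)
    uv = perspective_projection(points, rotation, translation, focal_length,
                                camera_center)
    assert torch.allclose(uv, torch.full((1, 1, 2), 112.))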


class MeshLoss(nn.Module):
    """Mix loss for 3D human mesh. It is composed of loss on 2D joints, 3D
    joints, mesh vertices and SMPL parameters (if any).

    Args:
        joints_2d_loss_weight (float): Weight for loss on 2D joints.
        joints_3d_loss_weight (float): Weight for loss on 3D joints.
        vertex_loss_weight (float): Weight for loss on 3D vertices.
        smpl_pose_loss_weight (float): Weight for loss on SMPL
            pose parameters.
        smpl_beta_loss_weight (float): Weight for loss on SMPL
            shape parameters.
        img_res (int): Input image resolution.
        focal_length (float): Focal length of camera model. Default=5000.
    """

    def __init__(self,
                 joints_2d_loss_weight,
                 joints_3d_loss_weight,
                 vertex_loss_weight,
                 smpl_pose_loss_weight,
                 smpl_beta_loss_weight,
                 img_res,
                 focal_length=5000):
        super().__init__()
        # Per-vertex loss on the mesh
        self.criterion_vertex = nn.L1Loss(reduction='none')
        # Joints (2D and 3D) loss
        self.criterion_joints_2d = nn.SmoothL1Loss(reduction='none')
        self.criterion_joints_3d = nn.SmoothL1Loss(reduction='none')
        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss(reduction='none')

        self.joints_2d_loss_weight = joints_2d_loss_weight
        self.joints_3d_loss_weight = joints_3d_loss_weight
        self.vertex_loss_weight = vertex_loss_weight
        self.smpl_pose_loss_weight = smpl_pose_loss_weight
        self.smpl_beta_loss_weight = smpl_beta_loss_weight
        self.focal_length = focal_length
        self.img_res = img_res

    def joints_2d_loss(self, pred_joints_2d, gt_joints_2d, joints_2d_visible):
        """Compute 2D reprojection loss on the joints.

        The loss is weighted by joints_2d_visible.
        """
        conf = joints_2d_visible.float()
        loss = (conf *
                self.criterion_joints_2d(pred_joints_2d, gt_joints_2d)).mean()
        return loss

    def joints_3d_loss(self, pred_joints_3d, gt_joints_3d, joints_3d_visible):
        """Compute 3D joints loss for the examples for which 3D joint
        annotations are available.

        The loss is weighted by joints_3d_visible.
        """
        conf = joints_3d_visible.float()
        if len(gt_joints_3d) > 0:
            # Align both skeletons at the pelvis (the midpoint of the two
            # hip joints, indices 2 and 3) before comparing them.
            gt_pelvis = (gt_joints_3d[:, 2, :] + gt_joints_3d[:, 3, :]) / 2
            gt_joints_3d = gt_joints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_joints_3d[:, 2, :] +
                           pred_joints_3d[:, 3, :]) / 2
            pred_joints_3d = pred_joints_3d - pred_pelvis[:, None, :]
            return (
                conf *
                self.criterion_joints_3d(pred_joints_3d, gt_joints_3d)).mean()
        # No 3D annotations in this batch: return a zero loss that keeps
        # the computation graph connected.
        return pred_joints_3d.sum() * 0

    def vertex_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute 3D vertex loss for the examples for which 3D human mesh
        annotations are available.

        The loss is weighted by has_smpl.
        """
        conf = has_smpl.float()
        loss_vertex = self.criterion_vertex(pred_vertices, gt_vertices)
        loss_vertex = (conf[:, None, None] * loss_vertex).mean()
        return loss_vertex

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Compute SMPL parameter loss for the examples for which SMPL
        parameter annotations are available.

        The loss is weighted by has_smpl.
        """
        conf = has_smpl.float()
        # Convert the ground-truth axis-angle pose (24 joints x 3) to
        # rotation matrices so that it matches the predicted format.
        gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
        loss_regr_pose = self.criterion_regr(pred_rotmat, gt_rotmat)
        loss_regr_betas = self.criterion_regr(pred_betas, gt_betas)
        loss_regr_pose = (conf[:, None, None, None] * loss_regr_pose).mean()
        loss_regr_betas = (conf[:, None] * loss_regr_betas).mean()
        return loss_regr_pose, loss_regr_betas

    def project_points(self, points_3d, camera):
        """Project 3D points to the image plane using the predicted
        weak-perspective camera, implemented as a full perspective
        projection with a fixed focal length.

        Note:
            - batch size: B
            - point number: N

        Args:
            points_3d (Tensor([B, N, 3])): 3D points.
            camera (Tensor([B, 3])): camera parameters whose 3 channels
                are (scale, translation_x, translation_y).

        Returns:
            Tensor([B, N, 2]): projected 2D points in image space.
        """
        batch_size = points_3d.shape[0]
        device = points_3d.device
        # Convert the weak-perspective camera (s, tx, ty) to a perspective
        # translation: the depth 2f / (s * img_res) makes the perspective
        # projection reproduce the scale s in [-1, 1] normalized
        # coordinates.
        cam_t = torch.stack([
            camera[:, 1], camera[:, 2], 2 * self.focal_length /
            (self.img_res * camera[:, 0] + 1e-9)
        ],
                            dim=-1)
        camera_center = camera.new_zeros([batch_size, 2])
        rot_t = torch.eye(
            3, device=device,
            dtype=points_3d.dtype).unsqueeze(0).expand(batch_size, -1, -1)
        joints_2d = perspective_projection(
            points_3d,
            rotation=rot_t,
            translation=cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        return joints_2d

    def forward(self, output, target):
        """Forward function.

        Args:
            output (dict): dict of network predicted results.
                Keys: 'vertices', 'joints_3d', 'camera',
                'pose' (optional), 'beta' (optional)
            target (dict): dict of ground-truth labels.
                Keys: 'vertices', 'joints_3d', 'joints_3d_visible',
                'joints_2d', 'joints_2d_visible', 'pose', 'beta',
                'has_smpl'

        Returns:
            dict: dict of losses.
        """
        losses = {}

        # Per-vertex loss for the shape
        pred_vertices = output['vertices']
        gt_vertices = target['vertices']
        has_smpl = target['has_smpl']
        loss_vertex = self.vertex_loss(pred_vertices, gt_vertices, has_smpl)
        losses['vertex_loss'] = loss_vertex * self.vertex_loss_weight

        # Compute loss on SMPL parameters, if available
        if 'pose' in output.keys() and 'beta' in output.keys():
            pred_rotmat = output['pose']
            pred_betas = output['beta']
            gt_pose = target['pose']
            gt_betas = target['beta']
            loss_regr_pose, loss_regr_betas = self.smpl_losses(
                pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl)
            losses['smpl_pose_loss'] = \
                loss_regr_pose * self.smpl_pose_loss_weight
            losses['smpl_beta_loss'] = \
                loss_regr_betas * self.smpl_beta_loss_weight

        # Compute 3D joints loss
        pred_joints_3d = output['joints_3d']
        gt_joints_3d = target['joints_3d']
        joints_3d_visible = target['joints_3d_visible']
        loss_joints_3d = self.joints_3d_loss(pred_joints_3d, gt_joints_3d,
                                             joints_3d_visible)
        losses['joints_3d_loss'] = loss_joints_3d * self.joints_3d_loss_weight

        # Compute 2D reprojection loss for the 2D joints
        pred_camera = output['camera']
        gt_joints_2d = target['joints_2d']
        joints_2d_visible = target['joints_2d_visible']
        pred_joints_2d = self.project_points(pred_joints_3d, pred_camera)

        # Normalize keypoints to [-1, 1]
        # The coordinate origin of pred_joints_2d is
        # the center of the input image.
        pred_joints_2d = 2 * pred_joints_2d / (self.img_res - 1)
        # The coordinate origin of gt_joints_2d is
        # the top left corner of the input image.
        gt_joints_2d = 2 * gt_joints_2d / (self.img_res - 1) - 1
        loss_joints_2d = self.joints_2d_loss(pred_joints_2d, gt_joints_2d,
                                             joints_2d_visible)
        losses['joints_2d_loss'] = loss_joints_2d * self.joints_2d_loss_weight

        return losses
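

# Illustrative smoke test (not part of the original module). The tensor
# shapes follow the SMPL convention this loss assumes: 24 pose joints,
# 10 shape coefficients and 6890 mesh vertices; the batch size and the
# random values are arbitrary.
def _demo_mesh_loss():
    B, num_joints, num_verts = 2, 24, 6890
    criterion = MeshLoss(
        joints_2d_loss_weight=1.0,
        joints_3d_loss_weight=1.0,
        vertex_loss_weight=1.0,
        smpl_pose_loss_weight=1.0,
        smpl_beta_loss_weight=1.0,
        img_res=224)
    output = {
        'vertices': torch.randn(B, num_verts, 3),
        'joints_3d': torch.randn(B, num_joints, 3),
        # Weak-perspective camera (scale, tx, ty); a positive scale
        # keeps the projected depth positive.
        'camera': torch.tensor([[1.0, 0.0, 0.0]]).expand(B, -1),
        'pose': torch.randn(B, 24, 3, 3),
        'beta': torch.randn(B, 10),
    }
    target = {
        'vertices': torch.randn(B, num_verts, 3),
        'joints_3d': torch.randn(B, num_joints, 3),
        'joints_3d_visible': torch.ones(B, num_joints, 1),
        'joints_2d': torch.rand(B, num_joints, 2) * 224,
        'joints_2d_visible': torch.ones(B, num_joints, 1),
        'pose': torch.randn(B, 72),
        'beta': torch.randn(B, 10),
        'has_smpl': torch.ones(B),
    }
    losses = criterion(output, target)
    assert set(losses) == {
        'vertex_loss', 'smpl_pose_loss', 'smpl_beta_loss', 'joints_3d_loss',
        'joints_2d_loss'
    }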


class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; it is always 1.0
            for discriminators.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    @staticmethod
    def _wgan_loss(input, target):
        """WGAN loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: WGAN loss.
        """
        return -input.mean() if target else input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan,
                otherwise, return Tensor.
        """
        if self.gan_type == 'wgan':
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for discriminators
                or not. Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
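

# Illustrative usage (not part of the original module): one generator
# step and one discriminator step with the vanilla (BCE) GAN loss; the
# logits are random stand-ins for a discriminator's output.
def _demo_gan_loss():
    gan_loss = GANLoss('vanilla', loss_weight=0.5)
    fake_pred = torch.randn(4, 1)
    # Generator update: try to make fakes look real, scaled by loss_weight.
    loss_g = gan_loss(fake_pred, target_is_real=True, is_disc=False)
    # Discriminator update: classify fakes as fake; loss_weight is ignored.
    loss_d = gan_loss(fake_pred, target_is_real=False, is_disc=True)
    return loss_g, loss_d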