Spaces:
Runtime error
Runtime error
''' | |
Adjusted version of other PyTorch implementation of the SMAL/SMPL model | |
see: | |
1.) https://github.com/silviazuffi/smalst/blob/master/smal_model/smal_torch.py | |
2.) https://github.com/benjiebob/SMALify/blob/master/smal_model/smal_torch.py | |
''' | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import torch | |
import numpy as np | |
def batch_skew(vec, batch_size=None): | |
""" | |
vec is N x 3, batch_size is int | |
returns N x 3 x 3. Skew_sym version of each matrix. | |
""" | |
device = vec.device | |
if batch_size is None: | |
batch_size = vec.shape.as_list()[0] | |
col_inds = torch.LongTensor([1, 2, 3, 5, 6, 7]) | |
indices = torch.reshape(torch.reshape(torch.arange(0, batch_size) * 9, [-1, 1]) + col_inds, [-1, 1]) | |
updates = torch.reshape( | |
torch.stack( | |
[ | |
-vec[:, 2], vec[:, 1], vec[:, 2], -vec[:, 0], -vec[:, 1], | |
vec[:, 0] | |
], | |
dim=1), [-1]) | |
out_shape = [batch_size * 9] | |
res = torch.Tensor(np.zeros(out_shape[0])).to(device=device) | |
res[np.array(indices.flatten())] = updates | |
res = torch.reshape(res, [batch_size, 3, 3]) | |
return res | |
def batch_rodrigues(theta): | |
""" | |
Theta is Nx3 | |
""" | |
device = theta.device | |
batch_size = theta.shape[0] | |
angle = (torch.norm(theta + 1e-8, p=2, dim=1)).unsqueeze(-1) | |
r = (torch.div(theta, angle)).unsqueeze(-1) | |
angle = angle.unsqueeze(-1) | |
cos = torch.cos(angle) | |
sin = torch.sin(angle) | |
outer = torch.matmul(r, r.transpose(1,2)) | |
eyes = torch.eye(3).unsqueeze(0).repeat([batch_size, 1, 1]).to(device=device) | |
H = batch_skew(r, batch_size=batch_size) | |
R = cos * eyes + (1 - cos) * outer + sin * H | |
return R | |
def batch_lrotmin(theta): | |
""" | |
Output of this is used to compute joint-to-pose blend shape mapping. | |
Equation 9 in SMPL paper. | |
Args: | |
pose: `Tensor`, N x 72 vector holding the axis-angle rep of K joints. | |
This includes the global rotation so K=24 | |
Returns | |
diff_vec : `Tensor`: N x 207 rotation matrix of 23=(K-1) joints with identity subtracted., | |
""" | |
# Ignore global rotation | |
theta = theta[:,3:] | |
Rs = batch_rodrigues(torch.reshape(theta, [-1,3])) | |
lrotmin = torch.reshape(Rs - torch.eye(3), [-1, 207]) | |
return lrotmin | |
def batch_global_rigid_transformation(Rs, Js, parent, rotate_base=False): | |
""" | |
Computes absolute joint locations given pose. | |
rotate_base: if True, rotates the global rotation by 90 deg in x axis. | |
if False, this is the original SMPL coordinate. | |
Args: | |
Rs: N x 24 x 3 x 3 rotation vector of K joints | |
Js: N x 24 x 3, joint locations before posing | |
parent: 24 holding the parent id for each index | |
Returns | |
new_J : `Tensor`: N x 24 x 3 location of absolute joints | |
A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. | |
""" | |
device = Rs.device | |
if rotate_base: | |
print('Flipping the SMPL coordinate frame!!!!') | |
rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) | |
rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile | |
root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) | |
else: | |
root_rotation = Rs[:, 0, :, :] | |
# Now Js is N x 24 x 3 x 1 | |
Js = Js.unsqueeze(-1) | |
N = Rs.shape[0] | |
def make_A(R, t): | |
# Rs is N x 3 x 3, ts is N x 3 x 1 | |
R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) | |
t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(device=device)], 1) | |
return torch.cat([R_homo, t_homo], 2) | |
A0 = make_A(root_rotation, Js[:, 0]) | |
results = [A0] | |
for i in range(1, parent.shape[0]): | |
j_here = Js[:, i] - Js[:, parent[i]] | |
A_here = make_A(Rs[:, i], j_here) | |
res_here = torch.matmul( | |
results[parent[i]], A_here) | |
results.append(res_here) | |
# 10 x 24 x 4 x 4 | |
results = torch.stack(results, dim=1) | |
new_J = results[:, :, :3, 3] | |
# --- Compute relative A: Skinning is based on | |
# how much the bone moved (not the final location of the bone) | |
# but (final_bone - init_bone) | |
# --- | |
Js_w0 = torch.cat([Js, torch.zeros([N, 35, 1, 1]).to(device=device)], 2) | |
init_bone = torch.matmul(results, Js_w0) | |
# Append empty 4 x 3: | |
init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) | |
A = results - init_bone | |
return new_J, A | |
######################################################################################### | |
def get_bone_length_scales(part_list, betas_logscale): | |
leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) | |
tail_joints = list(range(25, 32)) | |
ear_joints = [33, 34] | |
neck_joints = [15, 6] # ? | |
core_joints = [4, 5] # ? | |
mouth_joints = [16, 32] | |
log_scales = torch.zeros(betas_logscale.shape[0], 35).to(betas_logscale.device) | |
for ind, part in enumerate(part_list): | |
if part == 'legs_l': | |
log_scales[:, leg_joints] = betas_logscale[:, ind][:, None] | |
elif part == 'tail_l': | |
log_scales[:, tail_joints] = betas_logscale[:, ind][:, None] | |
elif part == 'ears_l': | |
log_scales[:, ear_joints] = betas_logscale[:, ind][:, None] | |
elif part == 'neck_l': | |
log_scales[:, neck_joints] = betas_logscale[:, ind][:, None] | |
elif part == 'core_l': | |
log_scales[:, core_joints] = betas_logscale[:, ind][:, None] | |
elif part == 'head_l': | |
log_scales[:, mouth_joints] = betas_logscale[:, ind][:, None] | |
else: | |
pass | |
all_scales = torch.exp(log_scales) | |
return all_scales[:, 1:] # don't count root | |
def get_beta_scale_mask(part_list): | |
# which joints belong to which bodypart | |
leg_joints = list(range(7,11)) + list(range(11,15)) + list(range(17,21)) + list(range(21,25)) | |
tail_joints = list(range(25, 32)) | |
ear_joints = [33, 34] | |
neck_joints = [15, 6] # ? | |
core_joints = [4, 5] # ? | |
mouth_joints = [16, 32] | |
n_b_log = len(part_list) #betas_logscale.shape[1] # 8 # 6 | |
beta_scale_mask = torch.zeros(35, 3, n_b_log) # .to(betas_logscale.device) | |
for ind, part in enumerate(part_list): | |
if part == 'legs_l': | |
beta_scale_mask[leg_joints, [2], [ind]] = 1.0 # Leg lengthening | |
elif part == 'legs_f': | |
beta_scale_mask[leg_joints, [0], [ind]] = 1.0 # Leg fatness | |
beta_scale_mask[leg_joints, [1], [ind]] = 1.0 # Leg fatness | |
elif part == 'tail_l': | |
beta_scale_mask[tail_joints, [0], [ind]] = 1.0 # Tail lengthening | |
elif part == 'tail_f': | |
beta_scale_mask[tail_joints, [1], [ind]] = 1.0 # Tail fatness | |
beta_scale_mask[tail_joints, [2], [ind]] = 1.0 # Tail fatness | |
elif part == 'ears_y': | |
beta_scale_mask[ear_joints, [1], [ind]] = 1.0 # Ear y | |
elif part == 'ears_l': | |
beta_scale_mask[ear_joints, [2], [ind]] = 1.0 # Ear z | |
elif part == 'neck_l': | |
beta_scale_mask[neck_joints, [0], [ind]] = 1.0 # Neck lengthening | |
elif part == 'neck_f': | |
beta_scale_mask[neck_joints, [1], [ind]] = 1.0 # Neck fatness | |
beta_scale_mask[neck_joints, [2], [ind]] = 1.0 # Neck fatness | |
elif part == 'core_l': | |
beta_scale_mask[core_joints, [0], [ind]] = 1.0 # Core lengthening | |
# beta_scale_mask[core_joints, [1], [ind]] = 1.0 # Core fatness (height) | |
elif part == 'core_fs': | |
beta_scale_mask[core_joints, [2], [ind]] = 1.0 # Core fatness (side) | |
elif part == 'head_l': | |
beta_scale_mask[mouth_joints, [0], [ind]] = 1.0 # Head lengthening | |
elif part == 'head_f': | |
beta_scale_mask[mouth_joints, [1], [ind]] = 1.0 # Head fatness 0 | |
beta_scale_mask[mouth_joints, [2], [ind]] = 1.0 # Head fatness 1 | |
else: | |
print(part + ' not available') | |
raise ValueError | |
beta_scale_mask = torch.transpose( | |
beta_scale_mask.reshape(35*3, n_b_log), 0, 1) | |
return beta_scale_mask | |
def batch_global_rigid_transformation_biggs(Rs, Js, parent, scale_factors_3x3, rotate_base = False, betas_logscale=None, opts=None): | |
""" | |
Computes absolute joint locations given pose. | |
rotate_base: if True, rotates the global rotation by 90 deg in x axis. | |
if False, this is the original SMPL coordinate. | |
Args: | |
Rs: N x 24 x 3 x 3 rotation vector of K joints | |
Js: N x 24 x 3, joint locations before posing | |
parent: 24 holding the parent id for each index | |
Returns | |
new_J : `Tensor`: N x 24 x 3 location of absolute joints | |
A : `Tensor`: N x 24 4 x 4 relative joint transformations for LBS. | |
""" | |
if rotate_base: | |
print('Flipping the SMPL coordinate frame!!!!') | |
rot_x = torch.Tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) | |
rot_x = torch.reshape(torch.repeat(rot_x, [N, 1]), [N, 3, 3]) # In tf it was tile | |
root_rotation = torch.matmul(Rs[:, 0, :, :], rot_x) | |
else: | |
root_rotation = Rs[:, 0, :, :] | |
# Now Js is N x 24 x 3 x 1 | |
Js = Js.unsqueeze(-1) | |
N = Rs.shape[0] | |
Js_orig = Js.clone() | |
def make_A(R, t): | |
# Rs is N x 3 x 3, ts is N x 3 x 1 | |
R_homo = torch.nn.functional.pad(R, (0,0,0,1,0,0)) | |
t_homo = torch.cat([t, torch.ones([N, 1, 1]).to(Rs.device)], 1) | |
return torch.cat([R_homo, t_homo], 2) | |
A0 = make_A(root_rotation, Js[:, 0]) | |
results = [A0] | |
for i in range(1, parent.shape[0]): | |
j_here = Js[:, i] - Js[:, parent[i]] | |
try: | |
s_par_inv = torch.inverse(scale_factors_3x3[:, parent[i]]) | |
except: | |
# import pdb; pdb.set_trace() | |
s_par_inv = torch.max(scale_factors_3x3[:, parent[i]], 0.01*torch.eye((3))[None, :, :].to(scale_factors_3x3.device)) | |
rot = Rs[:, i] | |
s = scale_factors_3x3[:, i] | |
rot_new = s_par_inv @ rot @ s | |
A_here = make_A(rot_new, j_here) | |
res_here = torch.matmul( | |
results[parent[i]], A_here) | |
results.append(res_here) | |
# 10 x 24 x 4 x 4 | |
results = torch.stack(results, dim=1) | |
# scale updates | |
new_J = results[:, :, :3, 3] | |
# --- Compute relative A: Skinning is based on | |
# how much the bone moved (not the final location of the bone) | |
# but (final_bone - init_bone) | |
# --- | |
Js_w0 = torch.cat([Js_orig, torch.zeros([N, 35, 1, 1]).to(Rs.device)], 2) | |
init_bone = torch.matmul(results, Js_w0) | |
# Append empty 4 x 3: | |
init_bone = torch.nn.functional.pad(init_bone, (3,0,0,0,0,0,0,0)) | |
A = results - init_bone | |
return new_J, A |