Spaces:
Runtime error
Runtime error
""" | |
original from https://github.com/vchoutas/smplx | |
modified by Vassilis and Yao | |
""" | |
import pickle | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from .lbs import ( | |
JointsFromVerticesSelector, | |
Struct, | |
find_dynamic_lmk_idx_and_bcoords, | |
lbs, | |
to_np, | |
to_tensor, | |
vertices2landmarks, | |
) | |
# SMPLX
# Names of the 14 joints produced by the J14 joint regressor; SMPLX.__init__
# maps each of them to its position in SMPLX_names.
J14_NAMES = [
    "right_ankle",
    "right_knee",
    "right_hip",
    "left_hip",
    "left_knee",
    "left_ankle",
    "right_wrist",
    "right_elbow",
    "right_shoulder",
    "left_shoulder",
    "left_elbow",
    "left_wrist",
    "neck",
    "head",
]
# Ordered keypoint names, presumably one per joint returned by SMPLX.forward
# (TODO confirm against the model files):
#   0-24    body joints, jaw and eyes
#   25-54   finger joints
#   55-122  68 face landmarks
# The extra vertex-selected joints (123-144) are appended from `extra_names`
# below; they must NOT also be listed here, otherwise they appear twice.
SMPLX_names = [
    # body, jaw and eyes (0-24)
    "pelvis",
    "left_hip",
    "right_hip",
    "spine1",
    "left_knee",
    "right_knee",
    "spine2",
    "left_ankle",
    "right_ankle",
    "spine3",
    "left_foot",
    "right_foot",
    "neck",
    "left_collar",
    "right_collar",
    "head",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "jaw",
    "left_eye_smplx",
    "right_eye_smplx",
    # fingers (25-54)
    "left_index1", "left_index2", "left_index3",
    "left_middle1", "left_middle2", "left_middle3",
    "left_pinky1", "left_pinky2", "left_pinky3",
    "left_ring1", "left_ring2", "left_ring3",
    "left_thumb1", "left_thumb2", "left_thumb3",
    "right_index1", "right_index2", "right_index3",
    "right_middle1", "right_middle2", "right_middle3",
    "right_pinky1", "right_pinky2", "right_pinky3",
    "right_ring1", "right_ring2", "right_ring3",
    "right_thumb1", "right_thumb2", "right_thumb3",
    # face landmarks (55-122)
    "right_eye_brow1", "right_eye_brow2", "right_eye_brow3", "right_eye_brow4", "right_eye_brow5",
    "left_eye_brow5", "left_eye_brow4", "left_eye_brow3", "left_eye_brow2", "left_eye_brow1",
    "nose1", "nose2", "nose3", "nose4",
    "right_nose_2", "right_nose_1", "nose_middle", "left_nose_1", "left_nose_2",
    "right_eye1", "right_eye2", "right_eye3", "right_eye4", "right_eye5", "right_eye6",
    "left_eye4", "left_eye3", "left_eye2", "left_eye1", "left_eye6", "left_eye5",
    "right_mouth_1", "right_mouth_2", "right_mouth_3", "mouth_top",
    "left_mouth_3", "left_mouth_2", "left_mouth_1",
    "left_mouth_5", "left_mouth_4", "mouth_bottom", "right_mouth_4", "right_mouth_5",
    "right_lip_1", "right_lip_2", "lip_top", "left_lip_2", "left_lip_1",
    "left_lip_3", "lip_bottom", "right_lip_3",
    "right_contour_1", "right_contour_2", "right_contour_3", "right_contour_4",
    "right_contour_5", "right_contour_6", "right_contour_7", "right_contour_8",
    "contour_middle",
    "left_contour_8", "left_contour_7", "left_contour_6", "left_contour_5",
    "left_contour_4", "left_contour_3", "left_contour_2", "left_contour_1",
]
# Extra joints selected from mesh vertices (indices 123-144 of SMPLX_names).
extra_names = [
    "head_top",
    "left_big_toe",
    "left_ear",
    "left_eye",
    "left_heel",
    "left_index",
    "left_middle",
    "left_pinky",
    "left_ring",
    "left_small_toe",
    "left_thumb",
    "nose",
    "right_big_toe",
    "right_ear",
    "right_eye",
    "right_heel",
    "right_index",
    "right_middle",
    "right_pinky",
    "right_ring",
    "right_small_toe",
    "right_thumb",
]
SMPLX_names += extra_names
# Keypoint indices (positions in SMPLX_names / the joint tensor) grouped by
# body part. Contiguous runs are spelled as ranges:
#   0-24 body joints, jaw and eyes; 25-39 left fingers; 40-54 right fingers;
#   55-122 face landmarks; 123-144 extra vertex-selected joints.
part_indices = {
    "body": np.array([*range(0, 25), *range(123, 128), 132, *range(134, 139), 143]),
    "torso": np.array(
        [0, 1, 2, 3, 6, 9, *range(12, 20), 22, 23, 24, *range(55, 60), *range(76, 145)]
    ),
    "head": np.array([12, 15, 22, 23, 24, *range(55, 124), 125, 126, 134, 136, 137]),
    "face": np.array([*range(55, 123)]),
    "upper": np.array([12, 13, 14, *range(55, 123)]),
    "hand": np.array(
        [20, 21, *range(25, 55), *range(128, 132), 133, *range(139, 143), 144]
    ),
    "left_hand": np.array([20, *range(25, 40), *range(128, 132), 133]),
    "right_hand": np.array([21, *range(40, 55), *range(139, 143), 144]),
}
# kinematic tree
# Joint indices from the head back to the root (pelvis). SMPLX.__init__ stores
# this as the "head_kin_chain" buffer and passes it to
# find_dynamic_lmk_idx_and_bcoords to compute the head rotation that selects
# the dynamic face landmarks.
head_kin_chain = [
    15,  # head
    12,  # neck
    9,  # spine3
    6,  # spine2
    3,  # spine1
    0,  # pelvis
]
# --smplx joints: index - SMPL-style name (name used in SMPLX_names)
# 00 - Global      (pelvis)
# 01 - L_Thigh     (left_hip)
# 02 - R_Thigh     (right_hip)
# 03 - Spine       (spine1)
# 04 - L_Calf      (left_knee)
# 05 - R_Calf      (right_knee)
# 06 - Spine1      (spine2)
# 07 - L_Foot      (left_ankle)
# 08 - R_Foot      (right_ankle)
# 09 - Spine2      (spine3)
# 10 - L_Toes      (left_foot)
# 11 - R_Toes      (right_foot)
# 12 - Neck        (neck)
# 13 - L_Shoulder  (left_collar)
# 14 - R_Shoulder  (right_collar)
# 15 - Head        (head)
# 16 - L_UpperArm  (left_shoulder)
# 17 - R_UpperArm  (right_shoulder)
# 18 - L_ForeArm   (left_elbow)
# 19 - R_ForeArm   (right_elbow)
# 20 - L_Hand      (left_wrist)
# 21 - R_Hand      (right_wrist)
# 22 - Jaw         (jaw)
# 23 - L_Eye       (left_eye_smplx)
# 24 - R_Eye       (right_eye_smplx)
class SMPLX(nn.Module):
    """
    Given smplx parameters, this class generates a differentiable SMPLX function
    which outputs a mesh and 3D joints
    """

    # Kinematic chains used by pose_abs2rel / pose_rel2abs: the target joint
    # first, followed by its ancestors down to the pelvis. Indices refer to the
    # full pose torch.cat([global_pose, body_pose], dim=1), so index 0 is the
    # global (pelvis) rotation and body_pose[:, i] is joint i + 1.
    _POSE_KIN_CHAINS = {
        # Pelvis -> Spine 1, 2, 3 -> Neck -> Head
        "head": (15, 12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Neck
        "neck": (12, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Right Collar -> Right shoulder
        # -> Right elbow -> Right wrist
        "right_wrist": (21, 19, 17, 14, 9, 6, 3, 0),
        # Pelvis -> Spine 1, 2, 3 -> Left Collar -> Left shoulder
        # -> Left elbow -> Left wrist
        "left_wrist": (20, 18, 16, 13, 9, 6, 3, 0),
    }

    def __init__(self, config):
        """Load the SMPLX model data and auxiliary regressors named in ``config``.

        Args:
            config: object exposing ``smplx_model_path``, ``n_shape``, ``n_exp``,
                ``extra_joint_path`` and ``j14_regressor_path``.
        """
        super(SMPLX, self).__init__()
        # NOTE: allow_pickle=True (and pickle.load below) can execute code stored
        # in the file; only point the config at trusted model files.
        ss = np.load(config.smplx_model_path, allow_pickle=True)
        smplx_model = Struct(**ss)
        self.dtype = torch.float32
        self.register_buffer(
            "faces_tensor",
            to_tensor(to_np(smplx_model.f, dtype=np.int64), dtype=torch.long),
        )
        # The vertices of the template model
        self.register_buffer(
            "v_template", to_tensor(to_np(smplx_model.v_template), dtype=self.dtype)
        )
        # The shape components and expression (expression space is the same as
        # FLAME): keep the first n_shape shape components and the first n_exp
        # expression components, which start at column 300 of shapedirs.
        shapedirs = to_tensor(to_np(smplx_model.shapedirs), dtype=self.dtype)
        shapedirs = torch.cat(
            [
                shapedirs[:, :, :config.n_shape],
                shapedirs[:, :, 300:300 + config.n_exp],
            ],
            2,
        )
        self.register_buffer("shapedirs", shapedirs)
        # The pose components, flattened and transposed to [num_pose_basis, -1]
        num_pose_basis = smplx_model.posedirs.shape[-1]
        posedirs = np.reshape(smplx_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer("posedirs", to_tensor(to_np(posedirs), dtype=self.dtype))
        self.register_buffer(
            "J_regressor", to_tensor(to_np(smplx_model.J_regressor), dtype=self.dtype)
        )
        # Parent index of every joint; the root gets -1 (no parent).
        parents = to_tensor(to_np(smplx_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer("parents", parents)
        self.register_buffer(
            "lbs_weights", to_tensor(to_np(smplx_model.weights), dtype=self.dtype)
        )
        # for face keypoints
        self.register_buffer(
            "lmk_faces_idx", torch.tensor(smplx_model.lmk_faces_idx, dtype=torch.long)
        )
        self.register_buffer(
            "lmk_bary_coords",
            torch.tensor(smplx_model.lmk_bary_coords, dtype=self.dtype),
        )
        self.register_buffer(
            "dynamic_lmk_faces_idx",
            torch.tensor(smplx_model.dynamic_lmk_faces_idx, dtype=torch.long),
        )
        self.register_buffer(
            "dynamic_lmk_bary_coords",
            torch.tensor(smplx_model.dynamic_lmk_bary_coords, dtype=self.dtype),
        )
        # pelvis to head, to calculate head yaw angle, then find the dynamic landmarks
        self.register_buffer("head_kin_chain", torch.tensor(head_kin_chain, dtype=torch.long))

        # -- default parameters, used by forward() for any omitted argument.
        # They are registered as buffers (non-trainable), not as parameters.
        self.register_buffer(
            "shape_params",
            nn.Parameter(torch.zeros([1, config.n_shape], dtype=self.dtype), requires_grad=False),
        )
        self.register_buffer(
            "expression_params",
            nn.Parameter(torch.zeros([1, config.n_exp], dtype=self.dtype), requires_grad=False),
        )
        # Poses are rotation matrices [number of joints, 3, 3], initialised to the
        # identity. head_pose and neck_pose are registered but not read by forward().
        # The registration order below matches the historical order of buffers.
        for buffer_name, num_joints in (
            ("global_pose", 1),
            ("head_pose", 1),
            ("neck_pose", 1),
            ("jaw_pose", 1),
            ("eye_pose", 2),
            ("body_pose", 21),
            ("left_hand_pose", 15),
            ("right_hand_pose", 15),
        ):
            self.register_buffer(
                buffer_name,
                nn.Parameter(
                    torch.eye(3, dtype=self.dtype).unsqueeze(0).repeat(num_joints, 1, 1),
                    requires_grad=False,
                ),
            )

        if config.extra_joint_path:
            self.extra_joint_selector = JointsFromVerticesSelector(fname=config.extra_joint_path)
        self.use_joint_regressor = True
        self.keypoint_names = SMPLX_names
        if self.use_joint_regressor:
            with open(config.j14_regressor_path, "rb") as f:
                j14_regressor = pickle.load(f, encoding="latin1")
            # source[i] is the position of a J14 joint in keypoint_names,
            # target[i] its position in J14_NAMES.
            source = []
            target = []
            for idx, name in enumerate(self.keypoint_names):
                if name in J14_NAMES:
                    source.append(idx)
                    target.append(J14_NAMES.index(name))
            source = np.asarray(source)
            target = np.asarray(target)
            self.register_buffer("source_idxs", torch.from_numpy(source))
            self.register_buffer("target_idxs", torch.from_numpy(target))
            self.register_buffer(
                "extra_joint_regressor",
                torch.from_numpy(j14_regressor).to(torch.float32)
            )
        self.part_indices = part_indices

    @staticmethod
    def _infer_batch_size(*tensors):
        """Return the leading dimension of the first non-None tensor, or 1 if all are None."""
        for tensor in tensors:
            if tensor is not None:
                return tensor.shape[0]
        return 1

    def forward(
        self,
        shape_params=None,
        expression_params=None,
        global_pose=None,
        body_pose=None,
        jaw_pose=None,
        eye_pose=None,
        left_hand_pose=None,
        right_hand_pose=None,
    ):
        """
        Args:
            shape_params: [N, number of shape parameters]
            expression_params: [N, number of expression parameters]
            global_pose: pelvis pose, [N, 1, 3, 3]
            body_pose: [N, 21, 3, 3]
            jaw_pose: [N, 1, 3, 3]
            eye_pose: [N, 2, 3, 3]
            left_hand_pose: [N, 15, 3, 3]
            right_hand_pose: [N, 15, 3, 3]
            Any argument left as None is replaced by the stored default buffer
            (zeros for shape/expression, identity rotations for poses),
            expanded to the batch size.
        Returns:
            vertices: [N, number of vertices, 3]
            landmarks: [N, number of landmarks (68 face keypoints), 3]
            joints: [N, number of smplx joints (145), 3] -- the lbs joints, the
                landmarks and (when configured) the extra vertex-selected joints,
                concatenated along dim 1.
        """
        # Batch size: from shape_params, else global_pose, else any other given
        # input; with no inputs at all a single default body is produced.
        batch_size = self._infer_batch_size(
            shape_params,
            global_pose,
            expression_params,
            body_pose,
            jaw_pose,
            eye_pose,
            left_hand_pose,
            right_hand_pose,
        )
        if shape_params is None:
            shape_params = self.shape_params.expand(batch_size, -1)
        if expression_params is None:
            expression_params = self.expression_params.expand(batch_size, -1)
        if global_pose is None:
            global_pose = self.global_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if body_pose is None:
            body_pose = self.body_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if jaw_pose is None:
            jaw_pose = self.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if eye_pose is None:
            eye_pose = self.eye_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if left_hand_pose is None:
            left_hand_pose = self.left_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
        if right_hand_pose is None:
            right_hand_pose = self.right_hand_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)

        shape_components = torch.cat([shape_params, expression_params], dim=1)
        full_pose = torch.cat(
            [
                global_pose,
                body_pose,
                jaw_pose,
                eye_pose,
                left_hand_pose,
                right_hand_pose,
            ],
            dim=1,
        )
        template_vertices = self.v_template.unsqueeze(0).expand(batch_size, -1, -1)
        # smplx
        vertices, joints = lbs(
            shape_components,
            full_pose,
            template_vertices,
            self.shapedirs,
            self.posedirs,
            self.J_regressor,
            self.parents,
            self.lbs_weights,
            dtype=self.dtype,
            pose2rot=False,
        )
        # face landmarks: static ones plus the head-pose dependent (dynamic) ones
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(batch_size, -1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).expand(batch_size, -1, -1)
        dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
            vertices,
            full_pose,
            self.dynamic_lmk_faces_idx,
            self.dynamic_lmk_bary_coords,
            self.head_kin_chain,
        )
        lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
        lmk_bary_coords = torch.cat([lmk_bary_coords, dyn_lmk_bary_coords], 1)
        landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords)

        final_joint_set = [joints, landmarks]
        if hasattr(self, "extra_joint_selector"):
            # Add any extra joints that might be needed
            extra_joints = self.extra_joint_selector(vertices, self.faces_tensor)
            final_joint_set.append(extra_joints)
        # Create the final joint set
        joints = torch.cat(final_joint_set, dim=1)
        # NOTE: the J14 regressor buffers (extra_joint_regressor, source_idxs,
        # target_idxs) registered in __init__ are not applied to the joints here.
        return vertices, landmarks, joints

    @classmethod
    def _kin_chain_for(cls, abs_joint, caller):
        """Return the kinematic chain of ``abs_joint``; raise NotImplementedError if unknown."""
        try:
            return cls._POSE_KIN_CHAINS[abs_joint]
        except (KeyError, TypeError):
            raise NotImplementedError(f"{caller} does not support: {abs_joint}") from None

    def pose_abs2rel(self, global_pose, body_pose, abs_joint="head"):
        """change absolute pose to relative pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        ``body_pose`` is expected to hold the ABSOLUTE rotation of ``abs_joint``;
        it is replaced by the rotation relative to its parent. The parent's
        absolute pose is detached, so no gradient flows into the ancestors.

        NOTE: ``body_pose`` is modified in place and also returned.

        Raises:
            NotImplementedError: if ``abs_joint`` is not one of "head", "neck",
                "right_wrist", "left_wrist".
        """
        kin_chain = self._kin_chain_for(abs_joint, "pose_abs2rel")
        batch_size = global_pose.shape[0]
        dtype = global_pose.dtype
        device = global_pose.device
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        # Accumulate root ... parent rotations (parent-most factor on the left).
        rel_rot_mat = (
            torch.eye(3, device=device, dtype=dtype).unsqueeze_(dim=0).repeat(batch_size, 1, 1)
        )
        for idx in kin_chain[1:]:
            rel_rot_mat = torch.bmm(full_pose[:, idx], rel_rot_mat)
        # This contains the absolute pose of the parent
        abs_parent_pose = rel_rot_mat.detach()
        # body_pose excludes the global joint, hence the "- 1".
        abs_joint_pose = body_pose[:, kin_chain[0] - 1]
        # abs_joint = abs_parent * rel_joint ==> rel_joint = abs_parent.T * abs_joint
        rel_joint_pose = torch.matmul(
            abs_parent_pose.reshape(-1, 3, 3).transpose(1, 2),
            abs_joint_pose.reshape(-1, 3, 3),
        )
        # Replace the new relative pose
        body_pose[:, kin_chain[0] - 1, :, :] = rel_joint_pose
        return body_pose

    def pose_rel2abs(self, global_pose, body_pose, abs_joint="head"):
        """change relative pose to absolute pose
        Basic knowledge for SMPLX kinematic tree:
                absolute pose = parent pose * relative pose
        Here, pose must be represented as rotation matrix (batch_sizexnx3x3)

        Returns:
            The absolute rotation of ``abs_joint``, shape [batch_size, 1, 3, 3].

        Raises:
            NotImplementedError: if ``abs_joint`` is not one of "head", "neck",
                "right_wrist", "left_wrist".
        """
        full_pose = torch.cat([global_pose, body_pose], dim=1)
        kin_chain = self._kin_chain_for(abs_joint, "pose_rel2abs")
        # Walk from the joint to the root, left-multiplying each ancestor:
        # R_root * ... * R_parent * R_joint.
        rel_rot_mat = torch.eye(3, device=full_pose.device, dtype=full_pose.dtype).unsqueeze_(dim=0)
        for idx in kin_chain:
            rel_rot_mat = torch.matmul(full_pose[:, idx], rel_rot_mat)
        abs_pose = rel_rot_mat[:, None, :, :]
        return abs_pose