"""Utilities for SMPL-X motion data: per-joint mask selection/recovery
(numpy and torch variants), linear time upsampling, and BEAT-format
npz save/load helpers."""
import numpy as np
import torch
import smplx
# Boolean masks over the 55 SMPL-X joints (index 0 is the root).
# "local_upper" keeps the joints listed in _UPPER_JOINTS (named as the
# upper-body subset); "local_full" keeps every joint except the root.
_UPPER_JOINTS = frozenset({3, 6, 9, *range(12, 22), *range(25, 55)})
MASK_DICT = {
    "local_upper": [i in _UPPER_JOINTS for i in range(55)],
    "local_full": [i != 0 for i in range(55)],
}
def select_with_mask(motion: np.ndarray, mask: list[bool]) -> np.ndarray:
    """Keep only the masked joints along the last axis.

    ``motion`` has shape (..., J*C) with J == len(mask); the result has
    shape (..., sum(mask)*C), preserving joint order.
    """
    keep = np.asarray(mask, dtype=bool)
    num_joints = keep.size
    channels = motion.shape[-1] // num_joints
    # View as (..., J, C), drop the masked-out joints, flatten back.
    per_joint = motion.reshape(motion.shape[:-1] + (num_joints, channels))
    kept = per_joint[..., keep, :]
    return kept.reshape(kept.shape[:-2] + (-1,))
def recover_from_mask(selected_motion: np.ndarray, mask: list[bool]) -> np.ndarray:
    """Inverse of ``select_with_mask``: scatter kept joints back, zero-filling
    the masked-out ones.

    ``selected_motion`` has shape (..., sum(mask)*C); the result has shape
    (..., len(mask)*C).
    """
    keep = np.asarray(mask, dtype=bool)
    kept_count = int(keep.sum())
    # Per-joint channel count inferred from the flattened last dimension.
    channels = selected_motion.shape[-1] // kept_count
    per_joint = selected_motion.reshape(
        selected_motion.shape[:-1] + (kept_count, channels)
    )
    full = np.zeros(
        per_joint.shape[:-2] + (keep.size, channels), dtype=per_joint.dtype
    )
    full[..., keep, :] = per_joint
    return full.reshape(full.shape[:-2] + (keep.size * channels,))
def select_with_mask_ts(motion: torch.Tensor, mask: list[bool]) -> torch.Tensor:
    """Torch analogue of ``select_with_mask``: keep masked joints along the
    last axis, (..., J*C) -> (..., sum(mask)*C)."""
    keep = torch.tensor(mask, dtype=torch.bool, device=motion.device)
    num_joints = keep.numel()
    channels = motion.shape[-1] // num_joints
    # View as (..., J, C), select kept joints, flatten back.
    per_joint = motion.reshape(*motion.shape[:-1], num_joints, channels)
    kept = per_joint[..., keep, :]
    return kept.reshape(*kept.shape[:-2], -1)
def recover_from_mask_ts(selected_motion: torch.Tensor, mask: list[bool]) -> torch.Tensor:
    """Torch analogue of ``recover_from_mask``: scatter kept joints back to
    the full layout, zero-filling masked-out joints,
    (..., sum(mask)*C) -> (..., len(mask)*C)."""
    keep = torch.tensor(mask, dtype=torch.bool, device=selected_motion.device)
    kept_count = int(keep.sum().item())
    channels = selected_motion.shape[-1] // kept_count
    per_joint = selected_motion.reshape(
        *selected_motion.shape[:-1], kept_count, channels
    )
    full = torch.zeros(
        *per_joint.shape[:-2], keep.numel(), channels,
        dtype=selected_motion.dtype, device=selected_motion.device,
    )
    full[..., keep, :] = per_joint
    return full.reshape(*full.shape[:-2], keep.numel() * channels)
def time_upsample_numpy(data: np.ndarray, k: int) -> np.ndarray:
    """Linearly upsample along the time axis by an integer factor ``k``.

    data: (..., t, c) -> output: (..., k*t, c). The k*t new samples are
    placed uniformly on [0, t-1] and each is a linear interpolation of its
    two neighbouring original frames.
    """
    if k == 1:
        return data.copy()
    shape = data.shape
    t = shape[-2]
    c = shape[-1]
    # Fix: with a single frame there is nothing to interpolate between;
    # the original clip(idx, 0, t-2) produced index -1 and a 0/0 weight
    # (NaNs). Repeating the frame is the correct limit.
    if t == 1:
        return np.repeat(data, k, axis=-2)
    original_t = np.arange(t)
    new_t = np.linspace(0, t - 1, k * t)
    # Flatten to (M*c, t) so interpolation indexes one axis.
    reshaped = data.reshape(-1, t, c).transpose(0, 2, 1)
    M = reshaped.shape[0]
    reshaped = reshaped.reshape(M * c, t)
    # Left-neighbour index of each new sample, clamped so idx+1 is valid.
    idx = np.searchsorted(original_t, new_t, side='right') - 1
    idx = np.clip(idx, 0, t - 2)
    idx1 = idx + 1
    x0 = original_t[idx]
    x1 = original_t[idx1]
    w = (new_t - x0) / (x1 - x0)
    f0 = reshaped[:, idx]
    f1 = reshaped[:, idx1]
    out = f0 + (f1 - f0) * w
    out = out.reshape(M, c, k * t).transpose(0, 2, 1)
    final_shape = shape[:-2] + (k * t, c)
    return out.reshape(final_shape)
def beat_format_save(
    save_path: str,
    motion_data: np.ndarray,
    mask: list[bool] = None,
    betas: np.ndarray = None,
    expressions: np.ndarray = None,
    trans: np.ndarray = None,
    upsample: int = None,
):
    """Save motion data as a BEAT-style SMPL-X npz archive.

    Args:
        save_path: destination for ``np.savez`` (numpy appends ``.npz``).
        motion_data: (T, D) axis-angle poses; if ``mask`` is given this is
            the masked subset and is scattered back to the full layout.
        mask: optional per-joint boolean mask used to recover full poses.
        betas: optional (T, 300) shape coefficients; zeros when omitted.
        expressions: optional (T, 100) expression coefficients; zeros when
            omitted.
        trans: optional (T, 3) root translation. When omitted, a constant
            offset is derived from a neutral-pose SMPL-X forward pass.
        upsample: optional integer factor; when > 1, every time series is
            linearly upsampled before saving.
    """
    n_frames = motion_data.shape[0]
    if betas is None:
        betas = np.zeros((n_frames, 300), dtype=motion_data.dtype)
    if expressions is None:
        expressions = np.zeros((n_frames, 100), dtype=motion_data.dtype)
    if trans is None:
        smplx_model = smplx.create(
            "./emage_evaltools/smplx_models/",
            model_type='smplx',
            gender='NEUTRAL_2020',
            use_face_contour=False,
            num_betas=300,
            num_expression_coeffs=100,
            ext='npz',
            use_pca=False
        ).eval()
        betas_ts = torch.from_numpy(betas[0:1]).float()
        # Fix: run inference under no_grad so the .numpy() call below can
        # never see a grad-tracking tensor, and no autograd graph is built
        # for this pure-inference pass.
        with torch.no_grad():
            output = smplx_model(
                betas=betas_ts,
                transl=torch.zeros(1, 3),
                expression=torch.zeros(1, 100),
                jaw_pose=torch.zeros(1, 3),
                global_orient=torch.zeros(1, 3),
                body_pose=torch.zeros(1, 63),
                left_hand_pose=torch.zeros(1, 45),
                right_hand_pose=torch.zeros(1, 45),
                return_joints=True,
                leye_pose=torch.zeros(1, 3),
                reye_pose=torch.zeros(1, 3)
            )
            # Midpoint of joints 10 and 11 (presumably the feet — confirm
            # against the SMPL-X joint table), negated so the rest-pose body
            # sits at the origin, then broadcast to every frame.
            trans = (output["joints"][:, 10, :] + output["joints"][:, 11, :]) / 2
            trans = -trans.repeat(n_frames, 1).numpy()
    if mask is not None:
        motion_data = recover_from_mask(motion_data, mask)
    if upsample is not None and upsample > 1:
        motion_data = time_upsample_numpy(motion_data, upsample)
        betas = time_upsample_numpy(betas, upsample)
        expressions = time_upsample_numpy(expressions, upsample)
        trans = time_upsample_numpy(trans, upsample)
    np.savez(
        save_path,
        betas=betas[0],
        poses=motion_data,
        expressions=expressions,
        trans=trans,
        model='smplx2020',
        gender='neutral',
        mocap_frame_rate=30
    )
def beat_format_load(load_path: str, mask: list[bool] = None):
    """Load a BEAT-format npz archive into a dict.

    Returns a dict with keys "poses", "betas", "expressions", "trans".
    When ``mask`` is given, poses are reduced to the masked joints via
    ``select_with_mask``.
    """
    archive = np.load(load_path, allow_pickle=True)
    result = {
        "poses": archive["poses"],
        "betas": archive["betas"],
        "expressions": archive["expressions"],
        "trans": archive["trans"],
    }
    if mask is not None:
        result["poses"] = select_with_mask(result["poses"], mask)
    return result