EMAGE / emage_utils /motion_io.py
H-Liu1997's picture
newapp
b03a8f2
import numpy as np
import torch
import smplx
MASK_DICT = {
"local_upper": [
False, False, False, True, False, False, True, False, False, True,
False, False, True, True, True, True, True, True, True, True,
True, True, False, False, False, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True
],
"local_full": [False] + [True]*54
}
def select_with_mask(motion: np.ndarray, mask: list[bool]) -> np.ndarray:
mask_arr = np.array(mask, dtype=bool)
j = len(mask_arr)
c_channels = motion.shape[-1] // j
new_shape = motion.shape[:-1] + (j, c_channels)
motion = motion.reshape(new_shape)
selected_motion = motion[..., mask_arr, :]
final_shape = selected_motion.shape[:-2] + (selected_motion.shape[-2]*selected_motion.shape[-1],)
selected_motion = selected_motion.reshape(final_shape)
return selected_motion
def recover_from_mask(selected_motion: np.ndarray, mask: list[bool]) -> np.ndarray:
mask_arr = np.array(mask, dtype=bool)
j = len(mask_arr)
# Infer c_channels from selected_motion's last dimension and sum(mask)
c_channels = selected_motion.shape[-1] // mask_arr.sum()
new_shape = selected_motion.shape[:-1] + (mask_arr.sum(), c_channels)
selected_motion = selected_motion.reshape(new_shape)
out_shape = selected_motion.shape[:-2] + (j, c_channels)
recovered = np.zeros(out_shape, dtype=selected_motion.dtype)
recovered[..., mask_arr, :] = selected_motion
final_shape = recovered.shape[:-2] + (j * c_channels,)
recovered = recovered.reshape(final_shape)
return recovered
def select_with_mask_ts(motion: torch.Tensor, mask: list[bool]) -> torch.Tensor:
mask_arr = torch.tensor(mask, dtype=torch.bool, device=motion.device)
j = len(mask_arr)
c_channels = motion.shape[-1] // j
new_shape = motion.shape[:-1] + (j, c_channels)
motion = motion.reshape(new_shape)
selected_motion = motion[..., mask_arr, :]
final_shape = selected_motion.shape[:-2] + (selected_motion.shape[-2]*selected_motion.shape[-1],)
selected_motion = selected_motion.reshape(final_shape)
return selected_motion
def recover_from_mask_ts(selected_motion: torch.Tensor, mask: list[bool]) -> torch.Tensor:
device = selected_motion.device
dtype = selected_motion.dtype
mask_arr = torch.tensor(mask, dtype=torch.bool, device=device)
j = len(mask_arr)
sum_mask = mask_arr.sum().item()
c_channels = selected_motion.shape[-1] // sum_mask
new_shape = selected_motion.shape[:-1] + (sum_mask, c_channels)
selected_motion = selected_motion.reshape(new_shape)
out_shape = list(selected_motion.shape[:-2]) + [j, c_channels]
recovered = torch.zeros(out_shape, dtype=dtype, device=device)
recovered[..., mask_arr, :] = selected_motion
final_shape = list(recovered.shape[:-2]) + [j * c_channels]
recovered = recovered.reshape(final_shape)
return recovered
def time_upsample_numpy(data: np.ndarray, k: int) -> np.ndarray:
# data: (..., t, c)
# output: (..., k*t, c)
if k == 1:
return data.copy()
shape = data.shape
t = shape[-2]
c = shape[-1]
# original and new time indices
original_t = np.arange(t)
new_t = np.linspace(0, t - 1, k * t)
# reshape to (M, c, t)
reshaped = data.reshape(-1, t, c).transpose(0, 2, 1)
M = reshaped.shape[0]
reshaped = reshaped.reshape(M * c, t)
# find interpolation indices
idx = np.searchsorted(original_t, new_t, side='right') - 1
idx = np.clip(idx, 0, t - 2)
idx1 = idx + 1
x0 = original_t[idx]
x1 = original_t[idx1]
w = (new_t - x0) / (x1 - x0)
f0 = reshaped[:, idx]
f1 = reshaped[:, idx1]
out = f0 + (f1 - f0) * w
out = out.reshape(M, c, k * t).transpose(0, 2, 1)
final_shape = shape[:-2] + (k * t, c)
return out.reshape(final_shape)
def beat_format_save(
save_path: str,
motion_data: np.ndarray,
mask: list[bool] = None,
betas: np.ndarray = None,
expressions: np.ndarray = None,
trans: np.ndarray = None,
upsample: int = None,
):
if betas is None:
betas = np.zeros((motion_data.shape[0], 300), dtype=motion_data.dtype)
if expressions is None:
expressions = np.zeros((motion_data.shape[0], 100), dtype=motion_data.dtype)
if trans is None:
smplx_model = smplx.create(
"./emage_evaltools/smplx_models/",
model_type='smplx',
gender='NEUTRAL_2020',
use_face_contour=False,
num_betas=300,
num_expression_coeffs=100,
ext='npz',
use_pca=False
).eval()
betas_ts = torch.from_numpy(betas[0:1]).float()
output = smplx_model(
betas=betas_ts,
transl=torch.zeros(1, 3),
expression=torch.zeros(1, 100),
jaw_pose=torch.zeros(1, 3),
global_orient=torch.zeros(1, 3),
body_pose=torch.zeros(1, 63),
left_hand_pose=torch.zeros(1, 45),
right_hand_pose=torch.zeros(1, 45),
return_joints=True,
leye_pose=torch.zeros(1, 3),
reye_pose=torch.zeros(1, 3)
)
trans = (output["joints"][:, 10, :] + output["joints"][:, 11, :]) / 2
# print(trans)
trans = -trans.repeat(motion_data.shape[0], 1).numpy()
if mask is not None:
motion_data = recover_from_mask(motion_data, mask)
if upsample is not None and upsample > 1:
motion_data = time_upsample_numpy(motion_data, upsample)
betas = time_upsample_numpy(betas, upsample)
expressions = time_upsample_numpy(expressions, upsample)
trans = time_upsample_numpy(trans, upsample)
np.savez(
save_path,
betas=betas[0],
poses=motion_data,
expressions=expressions,
trans=trans,
model='smplx2020',
gender='neutral',
mocap_frame_rate=30
)
def beat_format_load(load_path: str, mask: list[bool] = None):
data = np.load(load_path, allow_pickle=True)
poses = data['poses']
betas = data['betas']
expressions = data['expressions']
trans = data['trans']
if mask is not None:
poses = select_with_mask(poses, mask)
return {
"poses": poses,
"betas": betas,
"expressions": expressions,
"trans": trans
}