import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from transformers import PreTrainedModel
from .configuration_emage_audio import EmageAudioConfig, EmageVQVAEConvConfig, EmageVAEConvConfig
from .processing_emage_audio import (
    Quantizer, VQEncoderV5, VQDecoderV5, WavEncoder, MLP, PeriodicPositionalEncoding, VQEncoderV6,
    recover_from_mask_ts, rotation_6d_to_axis_angle, velocity2position, axis_angle_to_rotation_6d,
    rotation_6d_to_matrix, matrix_to_axis_angle, axis_angle_to_matrix, matrix_to_rotation_6d,
)


def inverse_selection_tensor(filtered_t, selection_array, n):
    # scatter the filtered joint channels back into a full (n, 165) axis-angle tensor
    selection_array = torch.from_numpy(selection_array).cuda()
    original_shape_t = torch.zeros((n, 165)).cuda()
    selected_indices = torch.where(selection_array == 1)[0]
    for i in range(n):
        original_shape_t[i, selected_indices] = filtered_t[i]
    return original_shape_t


class EmageVAEConv(PreTrainedModel):
    config_class = EmageVAEConvConfig
    base_model_prefix = "emage_vaeconv"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = VQEncoderV5(config)
        self.decoder = VQDecoderV5(config)

    def forward(self, inputs):
        pre_latent = self.encoder(inputs)
        rec_pose = self.decoder(pre_latent)
        return {"rec_pose": rec_pose}


class EmageVQVAEConv(PreTrainedModel):
    config_class = EmageVQVAEConvConfig
    base_model_prefix = "emage_vqvaeconv"

    def __init__(self, config):
        super().__init__(config)
        self.encoder = VQEncoderV5(config)
        self.quantizer = Quantizer(config.vae_codebook_size, config.vae_length, config.vae_quantizer_lambda)
        self.decoder = VQDecoderV5(config)

    def forward(self, inputs):
        pre_latent = self.encoder(inputs)
        embedding_loss, vq_latent, _, perplexity = self.quantizer(pre_latent)
        rec_pose = self.decoder(vq_latent)
        return {"poses_feat": vq_latent, "embedding_loss": embedding_loss, "perplexity": perplexity, "rec_pose": rec_pose}

    def map2index(self, inputs):
        pre_latent = self.encoder(inputs)
        index = self.quantizer.map2index(pre_latent)
        return index

    def map2latent(self, inputs):
        pre_latent = self.encoder(inputs)
        index = self.quantizer.map2index(pre_latent)
        z_q = self.quantizer.get_codebook_entry(index)
        return z_q

    def decode(self, index):
        z_q = self.quantizer.get_codebook_entry(index)
        rec_pose = self.decoder(z_q)
        return rec_pose

    def decode_from_latent(self, latent):
        # find the nearest codebook entry for each latent vector, then decode
        z_flattened = latent.contiguous().view(-1, self.quantizer.e_dim)
        d = (
            torch.sum(z_flattened ** 2, dim=1, keepdim=True)
            + torch.sum(self.quantizer.embedding.weight ** 2, dim=1)
            - 2 * torch.matmul(z_flattened, self.quantizer.embedding.weight.t())
        )
        min_encoding_indices = torch.argmin(d, dim=1)
        indices = min_encoding_indices.view(latent.shape[0], latent.shape[1])
        z_q = self.quantizer.get_codebook_entry(indices)
        rec_pose = self.decoder(z_q)
        return rec_pose
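

# Illustrative usage sketch (not part of the original file): how one part-wise
# EmageVQVAEConv could tokenize a continuous pose feature stream into codebook indices
# and reconstruct it. The helper name and the shape of `pose_seq` are assumptions for
# illustration; the exact feature layout depends on the released configs.
def _example_vqvae_roundtrip(vq_model: EmageVQVAEConv, pose_seq: torch.Tensor) -> torch.Tensor:
    """pose_seq: (bs, t, dim) continuous features in the layout the sub-model was trained on."""
    with torch.no_grad():
        indices = vq_model.map2index(pose_seq)  # nearest-codebook indices from the quantizer
        rec_pose = vq_model.decode(indices)     # reconstructed pose features
    return rec_pose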


class EmageVQModel(nn.Module):
    def __init__(self, face_model, upper_model, hands_model, lower_model, global_model):
        super().__init__()
        # 55-joint SMPL-X selection masks: upper-body chain (spine, neck, head, collars, shoulders, arms)
        self.joint_mask_upper = [
            False, False, False, True, False, False, True, False, False, True, False,
            False, True, True, True, True, True, True, True, True, True, True,
            False, False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False, False,
        ]
        # lower-body chain (pelvis, legs, feet)
        self.joint_mask_lower = [
            True, True, True, False, True, True, False, True, True, False, True,
            True, False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False, False,
            False, False, False, False, False, False, False, False, False, False, False,
        ]
        self.vq_model_face = face_model
        self.vq_model_upper = upper_model
        self.vq_model_hands = hands_model
        self.vq_model_lower = lower_model
        self.global_motion = global_model

    def spilt_inputs(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None):
        # split the full-body rot6d stream into the four part-wise VQ-VAE input streams
        bs, t, j6 = smplx_body_rot6d.shape
        smplx_body_rot6d = smplx_body_rot6d.reshape(bs, t, j6 // 6, 6)
        jaw_rot6d = smplx_body_rot6d[:, :, 22:23, :].reshape(bs, t, 6)
        face = torch.cat([jaw_rot6d, expression], dim=2)
        upper_rot6d = smplx_body_rot6d[:, :, self.joint_mask_upper, :].reshape(bs, t, 78)
        hands_rot6d = smplx_body_rot6d[:, :, 25:55, :].reshape(bs, t, 180)
        lower_rot6d = smplx_body_rot6d[:, :, self.joint_mask_lower, :].reshape(bs, t, 54)
        tar_contact = torch.zeros(bs, t, 4, device=smplx_body_rot6d.device) if tar_contact is None else tar_contact
        tar_trans = torch.zeros(bs, t, 3, device=smplx_body_rot6d.device) if tar_trans is None else tar_trans
        lower = torch.cat([lower_rot6d, tar_trans, tar_contact], dim=2)
        return dict(face=face, upper=upper_rot6d, hands=hands_rot6d, lower=lower)

    def map2index(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None):
        inputs = self.spilt_inputs(smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans)
        face_index = self.vq_model_face.map2index(inputs["face"])
        upper_index = self.vq_model_upper.map2index(inputs["upper"])
        hands_index = self.vq_model_hands.map2index(inputs["hands"])
        lower_index = self.vq_model_lower.map2index(inputs["lower"])
        return dict(face=face_index, upper=upper_index, hands=hands_index, lower=lower_index)

    def map2latent(self, smplx_body_rot6d, expression, tar_contact=None, tar_trans=None):
        inputs = self.spilt_inputs(smplx_body_rot6d, expression, tar_contact=tar_contact, tar_trans=tar_trans)
        face_latent = self.vq_model_face.map2latent(inputs["face"])
        upper_latent = self.vq_model_upper.map2latent(inputs["upper"])
        hands_latent = self.vq_model_hands.map2latent(inputs["hands"])
        lower_latent = self.vq_model_lower.map2latent(inputs["lower"])
        return dict(face=face_latent, upper=upper_latent, hands=hands_latent, lower=lower_latent)
    def decode(self, face_index=None, upper_index=None, hands_index=None, lower_index=None,
               face_latent=None, upper_latent=None, hands_latent=None, lower_latent=None,
               get_global_motion=False, ref_trans=None):
        # infer (bs, t) from whichever part-wise input is provided
        for input_tensor in [face_index, upper_index, hands_index, lower_index,
                             face_latent, upper_latent, hands_latent, lower_latent]:
            if input_tensor is not None:
                bs, t = input_tensor.shape[:2]
                break

        if face_index is not None:
            face_mix = self.vq_model_face.decode(face_index)  # bs, t, 106
            face_jaw_6d, expression = face_mix[:, :, :6], face_mix[:, :, 6:]
            face_jaw = rotation_6d_to_axis_angle(face_jaw_6d)
        elif face_latent is not None:
            face_mix = self.vq_model_face.decode_from_latent(face_latent)
            face_jaw_6d, expression = face_mix[:, :, :6], face_mix[:, :, 6:]
            face_jaw = rotation_6d_to_axis_angle(face_jaw_6d)
        else:
            face_jaw = torch.zeros(bs, t, 3, device=self.vq_model_face.device)
            expression = torch.zeros(bs, t, 100, device=self.vq_model_face.device)

        if upper_index is not None:
            upper_6d = self.vq_model_upper.decode(upper_index)  # bs, t, 78
            upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, t, -1, 6)).reshape(bs, t, -1)
        elif upper_latent is not None:
            upper_6d = self.vq_model_upper.decode_from_latent(upper_latent)
            upper = rotation_6d_to_axis_angle(upper_6d.reshape(bs, t, -1, 6)).reshape(bs, t, -1)
        else:
            upper = torch.zeros(bs, t, 39, device=self.vq_model_upper.device)

        if hands_index is not None:
            hands_6d = self.vq_model_hands.decode(hands_index)
            hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, t, -1, 6)).reshape(bs, t, -1)
        elif hands_latent is not None:
            hands_6d = self.vq_model_hands.decode_from_latent(hands_latent)
            hands = rotation_6d_to_axis_angle(hands_6d.reshape(bs, t, -1, 6)).reshape(bs, t, -1)
        else:
            hands = torch.zeros(bs, t, 90, device=self.vq_model_hands.device)

        if lower_index is not None:
            lower_mix = self.vq_model_lower.decode(lower_index)
            lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:]
            lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, t, -1, 6)).reshape(bs, t, -1)
        elif lower_latent is not None:
            lower_mix = self.vq_model_lower.decode_from_latent(lower_latent)
            lower_6d, transfoot = lower_mix[:, :, :-7], lower_mix[:, :, -7:]
            lower = rotation_6d_to_axis_angle(lower_6d.reshape(bs, t, -1, 6)).reshape(bs, t, -1)
        else:
            lower = torch.zeros(bs, t, 27, device=self.vq_model_lower.device)
            transfoot = torch.zeros(bs, t, 7, device=self.vq_model_lower.device)
            lower_6d = axis_angle_to_rotation_6d(lower.reshape(bs, t, -1, 3)).reshape(bs, t, -1)
            lower_mix = torch.cat([lower_6d, transfoot], dim=-1)

        # scatter the part-wise axis-angle back into the full 55-joint layout and merge
        upper2all = recover_from_mask_ts(upper, self.joint_mask_upper)
        hands2all = recover_from_mask_ts(hands, [False] * 25 + [True] * 30)
        lower2all = recover_from_mask_ts(lower, self.joint_mask_lower)
        all_motion_axis_angle = upper2all + hands2all + lower2all
        all_motion_axis_angle[:, :, 22 * 3:22 * 3 + 3] = face_jaw
        all_motion_rot6d = axis_angle_to_rotation_6d(all_motion_axis_angle.reshape(bs, t, 55, 3)).reshape(bs, t, 55 * 6)
        all_motion4inference = torch.cat([all_motion_rot6d, transfoot], dim=2)  # 330 + 3 + 4

        global_motion = None
        if get_global_motion:
            global_motion = self.get_global_motion(lower_mix, ref_trans)
        return dict(expression=expression, all_motion4inference=all_motion4inference,
                    motion_axis_angle=all_motion_axis_angle, trans=global_motion)

    def get_global_motion(self, lower_body, ref_trans):
        # predict root velocities from the lower-body stream and integrate them into world translation
        global_motion = self.global_motion(lower_body)
        rec_trans_v_s = global_motion["rec_pose"][:, :, 54:57]
        if len(ref_trans.shape) == 2:
            ref_trans = ref_trans.unsqueeze(0).repeat(rec_trans_v_s.shape[0], 1, 1)
        rec_x_trans = velocity2position(rec_trans_v_s[:, :, 0:1], 1 / 30, ref_trans[:, 0, 0:1])
        rec_z_trans = velocity2position(rec_trans_v_s[:, :, 2:3], 1 / 30, ref_trans[:, 0, 2:3])
        rec_y_trans = rec_trans_v_s[:, :, 1:2]
        global_motion = torch.cat([rec_x_trans, rec_y_trans, rec_z_trans], dim=-1)
        return global_motion
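

# Illustrative sketch (assumption, not prescribed by this file): composing the five
# pretrained part-wise models into EmageVQModel and decoding codebook indices back to
# full-body axis-angle motion plus world translation. How the sub-models are loaded
# (e.g. via from_pretrained) depends on the released checkpoints.
def _example_compose_and_decode(face_m, upper_m, hands_m, lower_m, global_m,
                                face_idx, upper_idx, hands_idx, lower_idx, ref_trans):
    vq = EmageVQModel(face_m, upper_m, hands_m, lower_m, global_m)
    with torch.no_grad():
        out = vq.decode(face_index=face_idx, upper_index=upper_idx,
                        hands_index=hands_idx, lower_index=lower_idx,
                        get_global_motion=True, ref_trans=ref_trans)
    # out["motion_axis_angle"]: (bs, t, 165) axis-angle for 55 joints; out["trans"]: (bs, t, 3)
    return out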


class EmageAudioModel(PreTrainedModel):
    config_class = EmageAudioConfig
    base_model_prefix = "emage_audio"

    def __init__(self, config: EmageAudioConfig):
        super().__init__(config)
        self.cfg = config
        # audio encoder
        self.audio_encoder_face = WavEncoder(self.cfg.audio_f)
        self.audio_encoder_body = WavEncoder(self.cfg.audio_f)
        # speaker id
        self.speaker_embedding_body = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size)
        self.speaker_embedding_face = nn.Embedding(self.cfg.speaker_dims, self.cfg.hidden_size)
        # mask embedding
        self.mask_embedding = nn.Parameter(torch.zeros(1, 1, self.cfg.pose_dims + 3 + 4))
        nn.init.normal_(self.mask_embedding, 0, self.cfg.hidden_size ** -0.5)
        # nn.init.normal_(self.speaker_embedding_body.weight, 0, self.cfg.hidden_size/2**-0.5)
        # nn.init.normal_(self.speaker_embedding_face.weight, 0, self.cfg.hidden_size*2**-0.5)
        # motion pre encoder
        args_top = copy.deepcopy(self.cfg)
        args_top.vae_layer = 3
        args_top.vae_length = self.cfg.motion_f
        args_top.vae_test_dim = self.cfg.pose_dims + 3 + 4
        self.motion_encoder = VQEncoderV6(args_top)
        self.bodyhints_face = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f)
        self.bodyhints_body = MLP(self.cfg.motion_f, self.cfg.hidden_size, self.cfg.motion_f)
        # motion encoder
        self.audio_body_motion_proj = nn.Linear(self.cfg.audio_f, self.cfg.hidden_size)
        self.moton_proj = nn.Linear(self.cfg.motion_f, self.cfg.hidden_size)
        self.position_embeddings = PeriodicPositionalEncoding(self.cfg.hidden_size, period=self.cfg.pose_length, max_seq_len=self.cfg.pose_length)
        self.transformer_en_layer = nn.TransformerEncoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2)
        self.motion_self_encoder = nn.TransformerEncoder(self.transformer_en_layer, num_layers=1)
        # cross attn
        self.audio_motion_cross_attn_layer = nn.TransformerDecoderLayer(d_model=self.cfg.hidden_size, nhead=4, dim_feedforward=self.cfg.hidden_size * 2)
        self.audio_motion_cross_attn = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=8)
        # feed forward
        self.motion2latent_upper = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size)
        self.motion2latent_hands = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size)
        self.motion2latent_lower = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hidden_size)
        # refine
        self.body_motion_decoder_upper = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1)
        self.body_motion_decoder_hands = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1)
        self.body_motion_decoder_lower = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=1)
        # decoder
        self.motion_out_proj_upper = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size)
        self.motion_out_proj_hands = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size)
        self.motion_out_proj_lower = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size)
        self.motion_cls_upper = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size)
        self.motion_cls_hands = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size)
        self.motion_cls_lower = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size)
        # face decoder
        self.audio_face_motion_proj = nn.Linear(self.cfg.audio_f + self.cfg.motion_f, self.cfg.hidden_size)
        self.face_motion_decoder = nn.TransformerDecoder(self.audio_motion_cross_attn_layer, num_layers=4)
        self.face_out_proj = nn.Linear(self.cfg.hidden_size, self.cfg.vae_codebook_size)
        self.face_cls = MLP(self.cfg.vae_codebook_size, self.cfg.hidden_size, self.cfg.vae_codebook_size)

    def forward(self, audio, speaker_id, masked_motion, mask, use_audio=True):
        audio2face_fea = self.audio_encoder_face(audio)
        audio2body_fea = self.audio_encoder_body(audio)
        bs, t, _ = audio2face_fea.shape
        speaker_motion_fea_proj = self.speaker_embedding_body(speaker_id).repeat(1, t, 1)
        speaker_face_fea_proj = self.speaker_embedding_face(speaker_id).repeat(1, t, 1)

        # mask motion
        masked_embeddings = self.mask_embedding.expand_as(masked_motion)
        masked_motion = torch.where(mask == 1, masked_embeddings, masked_motion)

        # motion token (spatial hints)
        body_hint = self.motion_encoder(masked_motion)
        body_hint_body = self.bodyhints_body(body_hint)
        body_hint_face = self.bodyhints_face(body_hint)

        # face branch: speaker query attends to audio + body-hint features
        audio2face_fea_proj = self.audio_face_motion_proj(torch.cat([audio2face_fea, body_hint_face], dim=2))
        # audio2face_fea_proj = self.position_embeddings(audio2face_fea_proj)
        # audio2face_fea_proj = speaker_face_fea_proj + audio2face_fea_proj
        face_proj = self.position_embeddings(speaker_face_fea_proj)
        decode_face = self.face_motion_decoder(tgt=face_proj.permute(1, 0, 2), memory=audio2face_fea_proj.permute(1, 0, 2)).permute(1, 0, 2)
        face_latent = self.face_out_proj(decode_face)
        classify_face = self.face_cls(face_latent)

        # motion self attn (temporal)
        masked_motion_proj = self.moton_proj(body_hint_body)
        masked_motion_proj = self.position_embeddings(masked_motion_proj)
        masked_motion_proj = speaker_motion_fea_proj + masked_motion_proj
        motion_fea = self.motion_self_encoder(masked_motion_proj.permute(1, 0, 2)).permute(1, 0, 2)

        # audio cross attn
        audio2body_fea_proj = self.audio_body_motion_proj(audio2body_fea)
        # audio2body_fea_proj = self.position_embeddings(audio2body_fea_proj)
        # audio2body_fea_proj = speaker_motion_fea_proj + audio2body_fea_proj
        motion_fea = motion_fea + speaker_motion_fea_proj
        motion_fea = self.position_embeddings(motion_fea)
        audio2body_fea_cross = self.audio_motion_cross_attn(tgt=motion_fea.permute(1, 0, 2), memory=audio2body_fea_proj.permute(1, 0, 2)).permute(1, 0, 2)
        if not use_audio:
            audio2body_fea_cross = audio2body_fea_cross * 0.0
        motion_fea = motion_fea + audio2body_fea_cross

        # mlp
        upper_latent = self.motion2latent_upper(motion_fea)
        hands_latent = self.motion2latent_hands(motion_fea)
        lower_latent = self.motion2latent_lower(motion_fea)

        # refine: each part attends to the sum of the other two parts
        motion_upper_refine = self.body_motion_decoder_upper(tgt=upper_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(hands_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2)
        motion_hands_refine = self.body_motion_decoder_hands(tgt=hands_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + lower_latent).permute(1, 0, 2)).permute(1, 0, 2)
        motion_lower_refine = self.body_motion_decoder_lower(tgt=lower_latent.permute(1, 0, 2) + speaker_motion_fea_proj.permute(1, 0, 2), memory=(upper_latent + hands_latent).permute(1, 0, 2)).permute(1, 0, 2)
        upper_latent = self.motion_out_proj_upper(upper_latent + motion_upper_refine)
        hands_latent = self.motion_out_proj_hands(hands_latent + motion_hands_refine)
        lower_latent = self.motion_out_proj_lower(lower_latent + motion_lower_refine)

        # decode body
        classify_upper = self.motion_cls_upper(upper_latent)
        classify_hands = self.motion_cls_hands(hands_latent)
        classify_lower = self.motion_cls_lower(lower_latent)
        return {
            "rec_face": face_latent,
            "rec_upper": upper_latent,
            "rec_hands": hands_latent,
            "rec_lower": lower_latent,
            "cls_face": classify_face,
            "cls_upper": classify_upper,
            "cls_hands": classify_hands,
            "cls_lower": classify_lower,
        }

    def inference(self, audio, speaker_id, vq_model, masked_motion=None, mask=None):
        # generate default mask and masked motion if not provided
        length = audio.shape[1] * 30 // 16000  # 16 kHz audio, 30 fps motion
        bs = audio.shape[0]
        fake_axis_angle = torch.zeros(bs, length, 55, 3).to(audio.device)
        fake_motion = axis_angle_to_rotation_6d(fake_axis_angle).reshape(bs, length, -1)
        fake_foot_and_trans = torch.zeros(bs, length, 7).to(audio.device)
        fake_motion = torch.cat([fake_motion, fake_foot_and_trans], dim=-1)
        if masked_motion is not None:
            fake_motion[:, :masked_motion.shape[1]] = masked_motion
        masked_motion = fake_motion

        fake_mask = torch.ones_like(masked_motion)
        if mask is not None:
            fake_mask[:, :mask.shape[1]] = mask
        mask = fake_mask

        # autoregressive inference over sliding windows with seed-frame overlap
        bs, total_len, c = masked_motion.shape
        window = self.cfg.pose_length
        pre_frames = self.cfg.seed_frames
        rounds = (total_len - pre_frames) // (window - pre_frames)
        remain = (total_len - pre_frames) % (window - pre_frames)

        rec_all_face = []
        rec_all_lower = []
        rec_all_upper = []
        rec_all_hands = []
        cls_all_face = []
        cls_all_lower = []
        cls_all_upper = []
        cls_all_hands = []

        last_motion = masked_motion[:, :pre_frames, :]
        for i in range(rounds):
            start_idx = i * (window - pre_frames)
            end_idx = start_idx + window
            window_mask = mask[:, start_idx:end_idx, :].clone()
            window_motion = masked_motion[:, start_idx:end_idx, :].clone()
            # seed the first pre_frames with provided motion where unmasked, otherwise with the previous window's output
            window_motion[:, :pre_frames, :] = torch.where(
                (window_mask[:, :pre_frames, :] == 0),
                masked_motion[:, start_idx:start_idx + pre_frames, :],
                last_motion,
            )
            window_mask[:, :pre_frames, :] = 0

            audio_slice_len = (end_idx - start_idx) * (16000 // 30)
            audio_slice = audio[:, start_idx * (16000 // 30): start_idx * (16000 // 30) + audio_slice_len]

            net_out_val = self.forward(audio_slice, speaker_id, masked_motion=window_motion, mask=window_mask, use_audio=True)

            _, cls_face = torch.max(F.log_softmax(net_out_val["cls_face"], dim=2), dim=2)
            _, cls_upper = torch.max(F.log_softmax(net_out_val["cls_upper"], dim=2), dim=2)
            _, cls_hands = torch.max(F.log_softmax(net_out_val["cls_hands"], dim=2), dim=2)
            _, cls_lower = torch.max(F.log_softmax(net_out_val["cls_lower"], dim=2), dim=2)

            # per config, pass either the continuous latent or the classified codebook index to the VQ decoder
            face_latent = net_out_val["rec_face"] if self.cfg.lf > 0 and self.cfg.cf == 0 else None
            upper_latent = net_out_val["rec_upper"] if self.cfg.lu > 0 and self.cfg.cu == 0 else None
            hands_latent = net_out_val["rec_hands"] if self.cfg.lh > 0 and self.cfg.ch == 0 else None
            lower_latent = net_out_val["rec_lower"] if self.cfg.ll > 0 and self.cfg.cl == 0 else None
            face_index = cls_face if self.cfg.cf > 0 else None
            upper_index = cls_upper if self.cfg.cu > 0 else None
            hands_index = cls_hands if self.cfg.ch > 0 else None
            lower_index = cls_lower if self.cfg.cl > 0 else None

            decode_dict = vq_model.decode(
                face_latent=face_latent, upper_latent=upper_latent, lower_latent=lower_latent, hands_latent=hands_latent,
                face_index=face_index, upper_index=upper_index, lower_index=lower_index, hands_index=hands_index,
            )
            # decode_dict = vq_model.decode(face_latent=net_out_val["rec_face"], upper_index=net_out_val["cls_upper"], hands_index=net_out_val["cls_hands"], lower_index=net_out_val["cls_lower"])
            last_motion = decode_dict["all_motion4inference"][:, -pre_frames:, :]

            rec_all_face.append(net_out_val["rec_face"][:, :-pre_frames, :])
            rec_all_upper.append(net_out_val["rec_upper"][:, :-pre_frames, :])
            rec_all_hands.append(net_out_val["rec_hands"][:, :-pre_frames, :])
            rec_all_lower.append(net_out_val["rec_lower"][:, :-pre_frames, :])
            cls_all_face.append(net_out_val["cls_face"][:, :-pre_frames])
            cls_all_upper.append(net_out_val["cls_upper"][:, :-pre_frames])
            cls_all_hands.append(net_out_val["cls_hands"][:, :-pre_frames])
            cls_all_lower.append(net_out_val["cls_lower"][:, :-pre_frames])
        # handle the tail that does not fill a full window
        if remain > pre_frames:
            final_start = rounds * (window - pre_frames)
            final_end = final_start + pre_frames + remain
            final_mask = mask[:, final_start:final_end, :].clone()
            final_motion = masked_motion[:, final_start:final_end, :].clone()
            final_motion[:, :pre_frames, :] = torch.where(
                (final_mask[:, :pre_frames, :] == 0),
                masked_motion[:, final_start:final_start + pre_frames, :],
                last_motion,
            )
            final_mask[:, :pre_frames, :] = 0

            audio_slice_len = (final_end - final_start) * (16000 // 30)
            audio_slice = audio[:, final_start * (16000 // 30): final_start * (16000 // 30) + audio_slice_len]

            net_out_val = self.forward(audio_slice, speaker_id, masked_motion=final_motion, mask=final_mask, use_audio=True)

            _, cls_face = torch.max(F.log_softmax(net_out_val["cls_face"], dim=2), dim=2)
            _, cls_upper = torch.max(F.log_softmax(net_out_val["cls_upper"], dim=2), dim=2)
            _, cls_hands = torch.max(F.log_softmax(net_out_val["cls_hands"], dim=2), dim=2)
            _, cls_lower = torch.max(F.log_softmax(net_out_val["cls_lower"], dim=2), dim=2)

            face_latent = net_out_val["rec_face"] if self.cfg.lf > 0 and self.cfg.cf == 0 else None
            upper_latent = net_out_val["rec_upper"] if self.cfg.lu > 0 and self.cfg.cu == 0 else None
            hands_latent = net_out_val["rec_hands"] if self.cfg.lh > 0 and self.cfg.ch == 0 else None
            lower_latent = net_out_val["rec_lower"] if self.cfg.ll > 0 and self.cfg.cl == 0 else None
            face_index = cls_face if self.cfg.cf > 0 else None
            upper_index = cls_upper if self.cfg.cu > 0 else None
            hands_index = cls_hands if self.cfg.ch > 0 else None
            lower_index = cls_lower if self.cfg.cl > 0 else None

            decode_dict = vq_model.decode(
                face_latent=face_latent, upper_latent=upper_latent, lower_latent=lower_latent, hands_latent=hands_latent,
                face_index=face_index, upper_index=upper_index, lower_index=lower_index, hands_index=hands_index,
            )

            rec_all_face.append(net_out_val["rec_face"])
            rec_all_upper.append(net_out_val["rec_upper"])
            rec_all_hands.append(net_out_val["rec_hands"])
            rec_all_lower.append(net_out_val["rec_lower"])
            cls_all_face.append(net_out_val["cls_face"])
            cls_all_upper.append(net_out_val["cls_upper"])
            cls_all_hands.append(net_out_val["cls_hands"])
            cls_all_lower.append(net_out_val["cls_lower"])

        rec_all_face = torch.cat(rec_all_face, dim=1)
        rec_all_upper = torch.cat(rec_all_upper, dim=1)
        rec_all_hands = torch.cat(rec_all_hands, dim=1)
        rec_all_lower = torch.cat(rec_all_lower, dim=1)
        cls_all_face = torch.cat(cls_all_face, dim=1)
        cls_all_upper = torch.cat(cls_all_upper, dim=1)
        cls_all_hands = torch.cat(cls_all_hands, dim=1)
        cls_all_lower = torch.cat(cls_all_lower, dim=1)
        return {
            "rec_face": rec_all_face,
            "rec_upper": rec_all_upper,
            "rec_hands": rec_all_hands,
            "rec_lower": rec_all_lower,
            "cls_face": cls_all_face,
            "cls_upper": cls_all_upper,
            "cls_hands": cls_all_hands,
            "cls_lower": cls_all_lower,
        }
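

# Illustrative end-to-end sketch (assumption, not part of the original file): running
# audio-to-motion inference with an already-instantiated EmageAudioModel and EmageVQModel.
# The 16 kHz audio rate and 30 fps motion rate come from the constants in `inference`
# above; the zero waveform, speaker id, and the latent/index combination passed to
# vq_model.decode are placeholders for illustration only.
def _example_audio_to_motion(model: EmageAudioModel, vq_model: EmageVQModel,
                             seconds: float = 4.0, speaker: int = 0, device: str = "cpu"):
    sr = 16000
    audio = torch.zeros(1, int(sr * seconds), device=device)   # placeholder 16 kHz waveform
    speaker_id = torch.tensor([[speaker]], device=device)      # (bs, 1) id for the speaker embedding
    with torch.no_grad():
        net_out = model.inference(audio, speaker_id, vq_model)
        t = net_out["rec_upper"].shape[1]
        decoded = vq_model.decode(
            face_latent=net_out["rec_face"],
            upper_index=net_out["cls_upper"].argmax(-1),
            hands_index=net_out["cls_hands"].argmax(-1),
            lower_index=net_out["cls_lower"].argmax(-1),
            get_global_motion=True,
            ref_trans=torch.zeros(1, t, 3, device=device),
        )
    return decoded["motion_axis_angle"], decoded["expression"], decoded["trans"]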