|
'''

Not exactly the same as the official repo, but the results are good.

'''

import os
import sys

sys.path.append(os.getcwd())

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from data_utils.lower_body import c_index_3d, c_index_6d
from data_utils.utils import get_mfcc_ta
from losses import KeypointLoss, L1Loss
from nets.base import TrainWrapperBaseClass
from nets.utils import denormalize
|
|
|
class Conv1d_tf(nn.Conv1d):
    """
    Conv1d with TF-style "SAME" padding behavior.
    Modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv1d_tf, self).__init__(*args, **kwargs)
        # Stored as a string; anything other than the literal "VALID" falls
        # through to SAME padding in forward(). Note the comparison is
        # case-sensitive, so the lowercase 'valid' used elsewhere in this file
        # also selects SAME padding.
        self.padding = kwargs.get("padding", "same")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)

        return additional_padding, total_padding

    def forward(self, input):
        if self.padding == "VALID":
            return F.conv1d(
                input,
                self.weight,
                self.bias,
                self.stride,
                padding=0,
                dilation=self.dilation,
                groups=self.groups,
            )
        # SAME padding: pad the right edge by one extra sample when the total
        # padding is odd, matching TensorFlow's asymmetric padding.
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        if rows_odd:
            input = F.pad(input, [0, rows_odd])

        return F.conv1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2),
            dilation=self.dilation,
            groups=self.groups,
        )
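# Quick shape check (a sketch): with SAME padding the output length is
# ceil(L / stride), independent of kernel size:
#
#   conv = Conv1d_tf(64, 64, kernel_size=4, stride=2)  # padding defaults to SAME
#   out = conv(torch.randn(2, 64, 15))                 # out.shape == (2, 64, 8)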
|
|
|
|
|
def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None, norm='bn', padding='valid'):
    # Default kernels: 3/1 keeps the temporal length (SAME padding), 4/2 halves it.
    if k is None and s is None:
        if not downsample:
            k = 3
            s = 1
        else:
            k = 4
            s = 2

    if type == '1d':
        conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        if norm == 'bn':
            norm_block = nn.BatchNorm1d(out_channels)
        elif norm == 'ln':
            norm_block = nn.LayerNorm(out_channels)
        else:
            raise ValueError(f'unsupported norm: {norm}')
    elif type == '2d':
        # NOTE: Conv2d_tf is not defined in this file; it is assumed to be
        # provided elsewhere in the repo (the '2d' branch is unused here).
        conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s, padding=padding)
        norm_block = nn.BatchNorm2d(out_channels)
    else:
        raise ValueError(f'unsupported conv type: {type}')

    return nn.Sequential(
        conv_block,
        norm_block,
        nn.LeakyReLU(0.2, True)
    )
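# Usage sketch: the lowercase 'valid' default is not matched by the
# case-sensitive "VALID" check in Conv1d_tf.forward, so these blocks always
# use SAME padding; a downsampling block therefore simply halves the length:
#
#   block = ConvNormRelu(64, 128, '1d', downsample=True)
#   out = block(torch.randn(2, 64, 16))  # out.shape == (2, 128, 8)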
|
|
|
class Decoder(nn.Module):
    """U-Net style decoder: upsample and fuse the skip features x3, x2, x1."""

    def __init__(self, in_ch, out_ch):
        super(Decoder, self).__init__()
        self.up1 = nn.Sequential(
            ConvNormRelu(in_ch // 2 + in_ch, in_ch // 2),
            ConvNormRelu(in_ch // 2, in_ch // 2),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up2 = nn.Sequential(
            ConvNormRelu(in_ch // 4 + in_ch // 2, in_ch // 4),
            ConvNormRelu(in_ch // 4, in_ch // 4),
            nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.up3 = nn.Sequential(
            ConvNormRelu(in_ch // 8 + in_ch // 4, in_ch // 8),
            ConvNormRelu(in_ch // 8, in_ch // 8),
            nn.Conv1d(in_ch // 8, out_ch, 1, 1)
        )

    def forward(self, x, x1, x2, x3):
        # Match each skip connection's temporal length before concatenating.
        x = F.interpolate(x, x3.shape[2])
        x = torch.cat([x, x3], dim=1)
        x = self.up1(x)
        x = F.interpolate(x, x2.shape[2])
        x = torch.cat([x, x2], dim=1)
        x = self.up2(x)
        x = F.interpolate(x, x1.shape[2])
        x = torch.cat([x, x1], dim=1)
        x = self.up3(x)
        return x
|
|
|
|
|
class EncoderDecoder(nn.Module):
    """U-Net over audio features: four downsampling stages, then one decoder
    head each for face, body, and hands."""

    def __init__(self, n_frames, each_dim):
        super().__init__()
        self.n_frames = n_frames

        self.down1 = nn.Sequential(
            ConvNormRelu(64, 64, '1d', False),
            ConvNormRelu(64, 128, '1d', False),
        )
        self.down2 = nn.Sequential(
            ConvNormRelu(128, 128, '1d', False),
            ConvNormRelu(128, 256, '1d', False),
        )
        self.down3 = nn.Sequential(
            ConvNormRelu(256, 256, '1d', False),
            ConvNormRelu(256, 512, '1d', False),
        )
        self.down4 = nn.Sequential(
            ConvNormRelu(512, 512, '1d', False),
            ConvNormRelu(512, 1024, '1d', False),
        )

        self.down = nn.MaxPool1d(kernel_size=2)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        # The face head also predicts the jaw pose (each_dim[0] + each_dim[3]).
        self.face_decoder = Decoder(1024, each_dim[0] + each_dim[3])
        self.body_decoder = Decoder(1024, each_dim[1])
        self.hand_decoder = Decoder(1024, each_dim[2])

    def forward(self, spectrogram, time_steps=None):
        # time_steps is kept for interface compatibility; the decoders recover
        # the input length from the skip connections instead.
        if time_steps is None:
            time_steps = self.n_frames

        x1 = self.down1(spectrogram)
        x = self.down(x1)
        x2 = self.down2(x)
        x = self.down(x2)
        x3 = self.down3(x)
        x = self.down(x3)
        x = self.down4(x)
        x = self.up(x)

        face = self.face_decoder(x, x1, x2, x3)
        body = self.body_decoder(x, x1, x2, x3)
        hand = self.hand_decoder(x, x1, x2, x3)

        return face, body, hand
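# Shape walkthrough (a sketch, with each_dim = [3, 39, 90, 100] as computed in
# TrainWrapper.init_params below):
#
#   ed = EncoderDecoder(n_frames=15, each_dim=[3, 39, 90, 100])
#   face, body, hand = ed(torch.randn(2, 64, 120))
#   # face: (2, 103, 120) -- jaw (3) + expression (100)
#   # body: (2, 39, 120)
#   # hand: (2, 90, 120)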
|
|
|
|
|
class Generator(nn.Module):
    def __init__(self,
                 each_dim,
                 training=False,
                 device=None
                 ):
        super().__init__()

        self.training = training
        self.device = device

        self.encoderdecoder = EncoderDecoder(15, each_dim)

    def forward(self, in_spec, time_steps=None):
        # time_steps is stored but unused: the output length follows in_spec.
        if time_steps is not None:
            self.gen_length = time_steps

        face, body, hand = self.encoderdecoder(in_spec)
        out = torch.cat([face, body, hand], dim=1)
        out = out.transpose(1, 2)  # (B, C, T) -> (B, T, C)

        return out
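# Output channel layout of Generator.forward, as consumed by the slices in
# TrainWrapper.get_loss:
#   out[:, :, 0:3]     jaw pose
#   out[:, :, 3:103]   expression
#   out[:, :, 103:142] body pose
#   out[:, :, 142:]    hand pose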
|
|
|
|
|
class Discriminator(nn.Module):
    """Conv discriminator over concatenated pose and audio features; outputs a
    real/fake score per window, at 1/8 of the input temporal resolution."""

    def __init__(self, input_dim):
        super().__init__()
        self.net = nn.Sequential(
            ConvNormRelu(input_dim, 128, '1d'),
            ConvNormRelu(128, 256, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(256, 256, '1d'),
            ConvNormRelu(256, 512, '1d'),
            nn.MaxPool1d(kernel_size=2),
            ConvNormRelu(512, 512, '1d'),
            ConvNormRelu(512, 1024, '1d'),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(1024, 1, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        out = self.net(x)
        return out
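# Shape sketch: with input_dim = 39 + 90 + 64 = 193 (body + hands + audio, as
# wired up in TrainWrapper below), scores come out at 1/8 temporal resolution:
#
#   d = Discriminator(193)
#   scores = d(torch.randn(2, 88, 193))  # scores.shape == (2, 1, 11)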
|
|
|
|
|
class TrainWrapper(TrainWrapperBaseClass):
    def __init__(self, args, config) -> None:
        self.args = args
        self.config = config
        self.device = torch.device(self.args.gpu)
        self.global_step = 0
        self.convert_to_6d = self.config.Data.pose.convert_to_6d
        self.init_params()

        self.generator = Generator(
            each_dim=self.each_dim,
            training=not self.args.infer,
            device=self.device,
        ).to(self.device)
        self.discriminator = Discriminator(
            input_dim=self.each_dim[1] + self.each_dim[2] + 64
        ).to(self.device)
        if self.convert_to_6d:
            self.c_index = c_index_6d
        else:
            self.c_index = c_index_3d
        self.MSELoss = KeypointLoss().to(self.device)
        self.L1Loss = L1Loss().to(self.device)
        super().__init__(args, config)
|
|
|
    def init_params(self):
        scale = 1

        # SMPL-X parameter dimensions (per frame).
        global_orient = round(0 * scale)
        leye_pose = reye_pose = round(0 * scale)
        jaw_pose = round(3 * scale)
        body_pose = round((63 - 24) * scale)
        left_hand_pose = right_hand_pose = round(45 * scale)

        expression = 100

        # Running offsets of each part within the pose vector.
        b_j = 0
        jaw_dim = jaw_pose
        b_e = b_j + jaw_dim
        eye_dim = leye_pose + reye_pose
        b_b = b_e + eye_dim
        body_dim = global_orient + body_pose
        b_h = b_b + body_dim
        hand_dim = left_hand_pose + right_hand_pose
        b_f = b_h + hand_dim
        face_dim = expression

        self.dim_list = [b_j, b_e, b_b, b_h, b_f]
        self.full_dim = jaw_dim + eye_dim + body_dim + hand_dim
        self.pose = int(self.full_dim / round(3 * scale))
        self.each_dim = [jaw_dim, eye_dim + body_dim, hand_dim, face_dim]
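    # Worked out with scale = 1:
    #   jaw_dim = 3, eye_dim = 0, body_dim = 39, hand_dim = 90, face_dim = 100
    #   dim_list = [0, 3, 3, 42, 132], full_dim = 132, pose = 132 / 3 = 44 joints
    #   each_dim = [3, 39, 90, 100]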
|
|
|
    def __call__(self, bat):
        assert (not self.args.infer), "infer mode"
        self.global_step += 1

        aud, poses = bat['aud_feat'].to(self.device).to(torch.float32), bat['poses'].to(self.device).to(torch.float32)
        expression = bat['expression'].to(self.device).to(torch.float32)
        jaw = poses[:, :3, :]
        poses = poses[:, self.c_index, :]

        pred = self.generator(in_spec=aud)

        # Discriminator step (generator output detached).
        D_loss, D_loss_dict = self.get_loss(
            pred_poses=pred.detach(),
            gt_poses=poses,
            aud=aud,
            mode='training_D',
        )

        self.discriminator_optimizer.zero_grad()
        D_loss.backward()
        self.discriminator_optimizer.step()

        # Generator step.
        G_loss, G_loss_dict = self.get_loss(
            pred_poses=pred,
            gt_poses=poses,
            aud=aud,
            expression=expression,
            jaw=jaw,
            mode='training_G',
        )
        self.generator_optimizer.zero_grad()
        G_loss.backward()
        self.generator_optimizer.step()

        total_loss = None
        loss_dict = {}
        for key in list(D_loss_dict.keys()) + list(G_loss_dict.keys()):
            loss_dict[key] = G_loss_dict.get(key, 0) + D_loss_dict.get(key, 0)

        return total_loss, loss_dict
|
|
|
    def get_loss(self,
                 pred_poses,
                 gt_poses,
                 aud=None,
                 jaw=None,
                 expression=None,
                 mode='training_G',
                 ):
        loss_dict = {}
        aud = aud.transpose(1, 2)
        gt_poses = gt_poses.transpose(1, 2)
        # The discriminator only sees body + hand channels (103:) plus audio.
        gt_aud = torch.cat([gt_poses, aud], dim=2)
        pred_aud = torch.cat([pred_poses[:, :, 103:], aud], dim=2)

        if mode == 'training_D':
            dis_real = self.discriminator(gt_aud)
            dis_fake = self.discriminator(pred_aud)
            # LSGAN-style objective: real -> 1, fake -> 0.
            dis_error = self.MSELoss(torch.ones_like(dis_real).to(self.device), dis_real) + self.MSELoss(
                torch.zeros_like(dis_fake).to(self.device), dis_fake)
            loss_dict['dis'] = dis_error

            return dis_error, loss_dict
        elif mode == 'training_G':
            jaw_loss = self.L1Loss(pred_poses[:, :, :3], jaw.transpose(1, 2))
            face_loss = self.MSELoss(pred_poses[:, :, 3:103], expression.transpose(1, 2))
            body_loss = self.L1Loss(pred_poses[:, :, 103:142], gt_poses[:, :, :39])
            hand_loss = self.L1Loss(pred_poses[:, :, 142:], gt_poses[:, :, 39:])
            l1_loss = jaw_loss + face_loss + body_loss + hand_loss

            dis_output = self.discriminator(pred_aud)
            gen_error = self.MSELoss(torch.ones_like(dis_output).to(self.device), dis_output)
            gen_loss = self.config.Train.weights.keypoint_loss_weight * l1_loss + self.config.Train.weights.gan_loss_weight * gen_error

            loss_dict['gen'] = gen_error
            loss_dict['jaw_loss'] = jaw_loss
            loss_dict['face_loss'] = face_loss
            loss_dict['body_loss'] = body_loss
            loss_dict['hand_loss'] = hand_loss
            return gen_loss, loss_dict
        else:
            raise ValueError(mode)
|
|
|
    def infer_on_audio(self, aud_fn, fps=30, initial_pose=None, norm_stats=None, id=None, B=1, **kwargs):
        assert self.args.infer, "train mode"
        self.generator.eval()

        if self.config.Data.pose.normalization:
            assert norm_stats is not None
            data_mean = norm_stats[0]
            data_std = norm_stats[1]

        aud_feat = get_mfcc_ta(aud_fn, sr=22000, fps=fps, smlpx=True, type='mfcc').transpose(1, 0)
        aud_feat = aud_feat[np.newaxis, ...].repeat(B, axis=0)
        aud_feat = torch.tensor(aud_feat, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            pred_poses = self.generator(aud_feat)
            pred_poses = pred_poses.cpu().numpy()
|
        output = pred_poses.squeeze()

        if self.config.Data.pose.normalization:
            # Assumption: denormalize(poses, mean, std) undoes the dataset
            # normalization; the original code asserted norm_stats but never
            # applied them.
            output = denormalize(output, data_mean, data_std)

        return output
|
|
|
    def generate(self, aud, id):
        self.generator.eval()
        with torch.no_grad():
            pred_poses = self.generator(aud)
        return pred_poses
|
|
|
|
|
if __name__ == '__main__':
    from trainer.options import parse_args

    parser = parse_args()
    args = parser.parse_args(
        ['--exp_name', '0', '--data_root', '0', '--speakers', '0', '--pre_pose_length', '4', '--generate_length', '64',
         '--infer'])

    # NOTE: TrainWrapper also needs a config object (the original smoke test
    # omitted it); load the experiment config before constructing the wrapper.
    config = None  # placeholder: replace with a loaded config
    generator = TrainWrapper(args, config)

    aud_fn = '../sample_audio/jon.wav'
    initial_pose = torch.randn(64, 108, 4)
    norm_stats = (np.random.randn(108), np.random.randn(108))
    # Pass by keyword: the second positional parameter of infer_on_audio is fps.
    output = generator.infer_on_audio(aud_fn, initial_pose=initial_pose, norm_stats=norm_stats)

    print(output.shape)
|
|