# models/emage_audio/processing_emage_audio.py — EMAGE audio-processing utilities
# (rotation-representation conversions and audio/motion VQ-VAE building blocks).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def _copysign(a, b):
signs_differ = (a < 0) != (b < 0)
return torch.where(signs_differ, -a, a)
def _sqrt_positive_part(x):
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix):
    """Convert rotation matrices to quaternions in (w, x, y, z) order.

    Args:
        matrix: tensor of shape (..., 3, 3); assumed to hold valid
            rotation matrices (not checked beyond the shape).

    Returns:
        Tensor of shape (..., 4); the real part `w` is non-negative by
        construction.

    Raises:
        ValueError: if the trailing two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        # Original raised a bare ValueError with no message; include the
        # offending shape so failures are diagnosable.
        raise ValueError(f"Invalid rotation matrix shape {tuple(matrix.shape)}.")
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    # Component magnitudes from the diagonal (trace identities).
    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
    # Component signs recovered from the off-diagonal differences.
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)
def quaternion_to_axis_angle(quaternions):
    """Convert quaternions (w, x, y, z), shape (..., 4), to axis-angle (..., 3).

    The output direction is the rotation axis and its norm the rotation
    angle in radians.
    """
    vec_norm = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half = torch.atan2(vec_norm, quaternions[..., :1])
    theta = 2 * half
    # sin(theta/2)/theta degenerates to 0/0 near zero; use its Taylor
    # expansion 1/2 - theta^2/48 there instead.
    tiny = theta.abs() < 1e-6
    ratio = torch.empty_like(theta)
    ratio[~tiny] = torch.sin(half[~tiny]) / theta[~tiny]
    ratio[tiny] = 0.5 - theta[tiny] * theta[tiny] / 48
    return quaternions[..., 1:] / ratio
def matrix_to_axis_angle(matrix):
    """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3)."""
    quats = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(quats)
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Map 6D rotation representations (..., 6) to matrices (..., 3, 3).

    Gram-Schmidt: the first triple becomes row 0, the second triple is
    orthogonalized against it for row 1, and row 2 is their cross product.
    """
    first, second = d6[..., :3], d6[..., 3:]
    row1 = F.normalize(first, dim=-1)
    proj = (row1 * second).sum(-1, keepdim=True)
    row2 = F.normalize(second - proj * row1, dim=-1)
    row3 = torch.cross(row1, row2, dim=-1)
    return torch.stack((row1, row2, row3), dim=-2)
def rotation_6d_to_axis_angle(rot6d):
    """Convert 6D rotation representations (..., 6) to axis-angle (..., 3)."""
    matrices = rotation_6d_to_matrix(rot6d)
    return matrix_to_axis_angle(matrices)
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """Flatten the top two rows of each 3x3 rotation into a 6-vector."""
    batch_shape = matrix.size()[:-2]
    return matrix[..., :2, :].clone().reshape(batch_shape + (6,))
def axis_angle_to_quaternion(axis_angle):
    """Convert axis-angle vectors (..., 3) to quaternions (w, x, y, z)."""
    theta = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half = 0.5 * theta
    # sin(theta/2)/theta is 0/0 at zero; swap in the Taylor expansion
    # 1/2 - theta^2/48 for very small angles.
    tiny = theta.abs() < 1e-6
    ratio = torch.empty_like(theta)
    ratio[~tiny] = torch.sin(half[~tiny]) / theta[~tiny]
    ratio[tiny] = 0.5 - theta[tiny] * theta[tiny] / 48
    return torch.cat([torch.cos(half), axis_angle * ratio], dim=-1)
def quaternion_to_matrix(quaternions):
    """Convert quaternions (w, x, y, z), shape (..., 4), to matrices (..., 3, 3)."""
    w, x, y, z = torch.unbind(quaternions, -1)
    # Scale absorbs the (possibly non-unit) quaternion norm.
    s = 2.0 / (quaternions * quaternions).sum(-1)
    row0 = (1 - s * (y * y + z * z), s * (x * y - z * w), s * (x * z + y * w))
    row1 = (s * (x * y + z * w), 1 - s * (x * x + z * z), s * (y * z - x * w))
    row2 = (s * (x * z - y * w), s * (y * z + x * w), 1 - s * (x * x + y * y))
    flat = torch.stack(row0 + row1 + row2, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def axis_angle_to_matrix(axis_angle):
    """Convert axis-angle vectors (..., 3) to rotation matrices (..., 3, 3)."""
    quats = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(quats)
def axis_angle_to_rotation_6d(axis_angle):
    """Convert axis-angle vectors (..., 3) to the 6D representation (..., 6)."""
    matrices = axis_angle_to_matrix(axis_angle)
    return matrix_to_rotation_6d(matrices)
def velocity2position(data_seq, dt, init_pos):
    """Integrate a velocity sequence into positions (explicit Euler).

    Args:
        data_seq: velocities, shape (batch, time, ...).
        dt: scalar time step.
        init_pos: position of the first frame, shape (batch, ...).

    Returns:
        Positions of shape (batch, time, ...): frame 0 is `init_pos` and
        frame i equals pos[i-1] + data_seq[:, i-1] * dt (the last
        velocity never contributes, matching the original recurrence).
    """
    start = init_pos.unsqueeze(1)
    # Vectorized prefix sum replaces the original per-frame Python loop;
    # cumsum accumulates in the same sequential order as the loop did.
    deltas = data_seq[:, :-1] * dt
    later = start + torch.cumsum(deltas, dim=1)
    return torch.cat([start, later], dim=1)
def recover_from_mask_ts(selected_motion: torch.Tensor, mask: list) -> torch.Tensor:
    """Scatter packed per-joint features back into the full-joint layout.

    `selected_motion` packs, along its last axis, fixed-size channel
    blocks for only the joints whose mask entry is truthy. The result has
    the same leading shape but a last axis of len(mask) * channels, with
    zeros at the masked-out joint positions.
    """
    keep = torch.tensor(mask, dtype=torch.bool, device=selected_motion.device)
    total_joints = keep.numel()
    kept_joints = int(keep.sum())
    channels = selected_motion.shape[-1] // kept_joints
    # Unflatten the last axis into (kept_joints, channels) blocks.
    packed = selected_motion.reshape(
        selected_motion.shape[:-1] + (kept_joints, channels)
    )
    full = torch.zeros(
        packed.shape[:-2] + (total_joints, channels),
        dtype=selected_motion.dtype,
        device=selected_motion.device,
    )
    full[..., keep, :] = packed
    return full.reshape(full.shape[:-2] + (total_joints * channels,))
class Quantizer(nn.Module):
    """Vector-quantization codebook (VQ-VAE style).

    Args:
        n_e: number of codebook entries.
        e_dim: dimensionality of each entry.
        beta: commitment-loss weight.
    """
    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0/self.n_e, 1.0/self.n_e)

    def _nearest_indices(self, z):
        """Index of the closest codebook entry for each flattened vector.

        Shared by forward() and map2index(), which previously duplicated
        this distance computation verbatim.
        """
        assert z.shape[-1] == self.e_dim
        z_flattened = z.contiguous().view(-1, self.e_dim)
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, without materializing
        # the pairwise difference tensor.
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2*torch.matmul(z_flattened, self.embedding.weight.t())
        return torch.argmin(d, dim=1)

    def forward(self, z):
        """Quantize z; returns (loss, z_q, indices, perplexity)."""
        min_encoding_indices = self._nearest_indices(z)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        # Codebook loss plus beta-weighted commitment loss.
        loss = torch.mean((z_q - z.detach())**2) + self.beta*torch.mean((z_q.detach() - z)**2)
        # Straight-through estimator: gradients bypass the quantization.
        z_q = z + (z_q - z).detach()
        min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean*torch.log(e_mean+1e-10)))
        return loss, z_q, min_encoding_indices, perplexity

    def map2index(self, z):
        """Nearest-code indices for z, reshaped to (batch, -1)."""
        return self._nearest_indices(z).reshape(z.shape[0], -1)

    def get_codebook_entry(self, indices):
        """Look up codebook vectors; returns shape indices.shape + (e_dim,)."""
        z_q = self.embedding(indices.view(-1))
        return z_q.view(indices.shape+(self.e_dim,)).contiguous()
def init_weight(m):
    """Xavier-normal-initialize conv/linear layers and zero their biases.

    Intended for use with nn.Module.apply(); modules of other types are
    left untouched.
    """
    if isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
class ResBlock(nn.Module):
    """Channel-preserving residual block: two 3-tap convs with a LeakyReLU
    in between, added to the identity shortcut."""
    def __init__(self, channel):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(channel, channel, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(channel, channel, 3, 1, 1),
        )

    def forward(self, x):
        residual = self.model(x)
        return residual + x
class VQEncoderV5(nn.Module):
    """Temporal conv encoder: (B, T, args.vae_test_dim) -> (B, T, args.vae_length).

    Stacks args.vae_layer groups of [Conv1d, LeakyReLU, ResBlock]; every
    hidden layer uses the same width, args.vae_length.
    """
    def __init__(self, args):
        super().__init__()
        depth = args.vae_layer
        width = args.vae_length
        modules = [
            nn.Conv1d(args.vae_test_dim, width, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            ResBlock(width),
        ]
        for _ in range(depth - 1):
            modules.extend([
                nn.Conv1d(width, width, 3, 1, 1),
                nn.LeakyReLU(0.2, True),
                ResBlock(width),
            ])
        self.main = nn.Sequential(*modules)
        self.main.apply(init_weight)

    def forward(self, inputs):
        # Conv1d expects (B, C, T); callers work in (B, T, C).
        hidden = self.main(inputs.permute(0, 2, 1))
        return hidden.permute(0, 2, 1)
class VQEncoderV6(nn.Module):
    """Temporal conv encoder: (B, T, args.vae_test_dim) -> (B, T, args.vae_length).

    NOTE(review): architecturally identical to VQEncoderV5 in this file;
    presumably kept as a distinct class name for config/checkpoint
    compatibility — confirm before merging the two.
    """
    def __init__(self, args):
        super().__init__()
        n_down = args.vae_layer
        width = args.vae_length
        stack = [
            nn.Conv1d(args.vae_test_dim, width, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            ResBlock(width),
        ]
        for _ in range(1, n_down):
            stack += [
                nn.Conv1d(width, width, 3, 1, 1),
                nn.LeakyReLU(0.2, True),
                ResBlock(width),
            ]
        self.main = nn.Sequential(*stack)
        self.main.apply(init_weight)

    def forward(self, inputs):
        # (B, T, C) -> (B, C, T) for convolution, then back.
        out = self.main(inputs.permute(0, 2, 1))
        return out.permute(0, 2, 1)
class VQDecoderV5(nn.Module):
    """Temporal conv decoder: (B, T, args.vae_length) -> (B, T, args.vae_test_dim).

    Two ResBlocks at the input width, then args.vae_layer Conv+LeakyReLU
    stages narrowing to args.vae_test_dim, and a final smoothing conv.
    """
    def __init__(self, args):
        super().__init__()
        n_up = args.vae_layer
        channels = [args.vae_length] * n_up + [args.vae_test_dim]
        input_size = args.vae_length
        # Project to the first channel width only when it differs
        # (possible when n_up == 0 and channels[0] is the output dim).
        if input_size == channels[0]:
            layers = []
        else:
            layers = [nn.Conv1d(input_size, channels[0], 3, 1, 1)]
        layers += [ResBlock(channels[0]) for _ in range(2)]
        for idx in range(n_up):
            layers += [
                nn.Conv1d(channels[idx], channels[idx + 1], 3, 1, 1),
                nn.LeakyReLU(0.2, True),
            ]
        layers.append(nn.Conv1d(channels[-1], channels[-1], 3, 1, 1))
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        # (B, T, C) -> (B, C, T) for Conv1d, back to (B, T, C) after.
        return self.main(inputs.permute(0, 2, 1)).permute(0, 2, 1)
class BasicBlock(nn.Module):
    """1-D residual basic block; based on timm: https://github.com/rwightman/pytorch-image-models

    NOTE(review): despite its name, `first_dilation` is used as the
    *padding* of the first conv (and of the shortcut conv), and
    `downsample` acts as a boolean flag — any non-None value enables a
    conv+norm shortcut. Confirm against callers before renaming.
    """
    def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, dilation=1, first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d):
        super().__init__()
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=ker_size, stride=stride,
                               padding=first_dilation, dilation=dilation, bias=True)
        self.bn1 = norm_layer(planes)
        self.act1 = act_layer(inplace=True)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=ker_size,
                               padding=ker_size // 2, dilation=dilation, bias=True)
        self.bn2 = norm_layer(planes)
        self.act2 = act_layer(inplace=True)
        if downsample is None:
            self.downsample = None
        else:
            # Shortcut matches conv1's geometry so shapes line up.
            self.downsample = nn.Sequential(
                nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size,
                          padding=first_dilation, dilation=dilation, bias=True),
                norm_layer(planes),
            )

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + shortcut
        return self.act2(out)
class WavEncoder(nn.Module):
    """Strided residual conv stack turning raw audio into frame features.

    Accepts (B, T) mono waveforms or (B, T, C) channel-last input and
    returns (B, T', out_dim), where T' is T reduced by the conv strides.
    """
    def __init__(self, out_dim, audio_in=1):
        super().__init__()
        self.out_dim = out_dim
        quarter, half = out_dim // 4, out_dim // 2
        self.feat_extractor = nn.Sequential(
            BasicBlock(audio_in, quarter, 15, 5, first_dilation=1600, downsample=True),
            BasicBlock(quarter, quarter, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(quarter, quarter, 15, 1, first_dilation=7),
            BasicBlock(quarter, half, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(half, half, 15, 1, first_dilation=7),
            BasicBlock(half, out_dim, 15, 3, first_dilation=0, downsample=True),
        )

    def forward(self, wav_data):
        # (B, T) -> (B, 1, T); (B, T, C) -> (B, C, T).
        if wav_data.dim() == 2:
            feats = wav_data.unsqueeze(1)
        else:
            feats = wav_data.transpose(1, 2)
        return self.feat_extractor(feats).transpose(1, 2)
class MLP(nn.Module):
    """Two-layer perceptron with a LeakyReLU(0.1) between the layers."""
    def __init__(self, in_dim, middle_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, middle_dim)
        self.fc2 = nn.Linear(middle_dim, out_dim)
        self.act = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that repeats every `period` steps.

    A standard sin/cos table of length `period` is tiled to cover at
    least `max_seq_len` positions, so positions p and p + period share
    the same encoding. Dropout is applied after the addition.
    """
    def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        table = torch.zeros(period, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        tiles = max_seq_len // period + 1
        # Buffer (not a parameter): moves with .to()/state_dict, no grads.
        self.register_buffer('pe', table.unsqueeze(0).repeat(1, tiles, 1))

    def forward(self, x):
        # Add the encoding for the first x.size(1) positions (broadcast over batch).
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)