import math

import torch
import torch.nn as nn
import torch.nn.functional as F
def _copysign(a, b):
    """Return a tensor with the magnitudes of `a` and the signs of `b`, elementwise."""
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)


def _sqrt_positive_part(x):
    """Return sqrt(max(0, x)) elementwise; zero where x <= 0 to avoid NaNs."""
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret
def matrix_to_quaternion(matrix):
    """Convert rotation matrices (..., 3, 3) to quaternions (..., 4) in (w, x, y, z) order."""
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
    # Recover the signs of the vector part from the off-diagonal differences.
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)
def quaternion_to_axis_angle(quaternions):
    """Convert quaternions (..., 4), (w, x, y, z), to axis-angle vectors (..., 3)."""
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # Taylor expansion of sin(x/2)/x around x = 0: 1/2 - x**2/48.
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    return quaternions[..., 1:] / sin_half_angles_over_angles


def matrix_to_axis_angle(matrix):
    """Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3)."""
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert 6D rotations (..., 6) to rotation matrices (..., 3, 3) by
    Gram-Schmidt orthogonalization (Zhou et al., CVPR 2019)."""
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def rotation_6d_to_axis_angle(rot6d):
    return matrix_to_axis_angle(rotation_6d_to_matrix(rot6d))


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """Drop the last row of the matrix; the first two rows are the 6D code."""
    return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
def axis_angle_to_quaternion(axis_angle):
    """Convert axis-angle vectors (..., 3) to quaternions (..., 4) in (w, x, y, z) order."""
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = 0.5 * angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # Same small-angle Taylor expansion as in quaternion_to_axis_angle.
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions


def quaternion_to_matrix(quaternions):
    """Convert quaternions (..., 4), (w, x, y, z), to rotation matrices (..., 3, 3)."""
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)  # handles non-unit quaternions
    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def axis_angle_to_matrix(axis_angle):
    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))


def axis_angle_to_rotation_6d(axis_angle):
    return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))
def velocity2position(data_seq, dt, init_pos):
    """Integrate per-frame velocities (B, T, C) into positions with explicit
    Euler steps, starting from init_pos (B, C)."""
    res_trans = []
    for i in range(data_seq.shape[1]):
        if i == 0:
            res_trans.append(init_pos.unsqueeze(1))
        else:
            res = data_seq[:, i - 1 : i] * dt + res_trans[-1]
            res_trans.append(res)
    return torch.cat(res_trans, dim=1)
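
# Illustrative usage (not part of the original module): a constant unit
# velocity at an assumed 30 fps advances each position by dt per frame.
def _demo_velocity2position():
    vel = torch.ones(2, 5, 3)                    # (batch, frames, xyz)
    init = torch.zeros(2, 3)
    pos = velocity2position(vel, dt=1.0 / 30, init_pos=init)
    print(pos[0, :, 0])                          # 0.0000, 0.0333, 0.0667, 0.1000, 0.1333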
def recover_from_mask_ts(selected_motion: torch.Tensor, mask: list) -> torch.Tensor:
    """Scatter the kept joints' features back into a full-skeleton tensor,
    filling masked-out joints with zeros.

    selected_motion: (..., sum(mask) * C); mask: per-joint booleans of length j.
    Returns a tensor of shape (..., j * C).
    """
    device = selected_motion.device
    dtype = selected_motion.dtype
    mask_arr = torch.tensor(mask, dtype=torch.bool, device=device)
    j = len(mask_arr)
    sum_mask = mask_arr.sum().item()
    c_channels = selected_motion.shape[-1] // sum_mask
    new_shape = selected_motion.shape[:-1] + (sum_mask, c_channels)
    selected_motion = selected_motion.reshape(new_shape)
    out_shape = list(selected_motion.shape[:-2]) + [j, c_channels]
    recovered = torch.zeros(out_shape, dtype=dtype, device=device)
    recovered[..., mask_arr, :] = selected_motion
    final_shape = list(recovered.shape[:-2]) + [j * c_channels]
    recovered = recovered.reshape(final_shape)
    return recovered
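
# Illustrative usage (not part of the original module): 2 of 4 joints were
# kept with 3 channels each; the shapes below are assumptions for the sketch.
def _demo_recover_from_mask_ts():
    mask = [True, False, True, False]
    selected = torch.ones(1, 8, 2 * 3)           # (batch, frames, kept_joints * channels)
    full = recover_from_mask_ts(selected, mask)
    print(full.shape)                            # torch.Size([1, 8, 12]); zeros at masked joints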
class Quantizer(nn.Module):
    """Vector quantizer (VQ-VAE): maps each latent vector to its nearest
    codebook entry, with gradients passed through via the straight-through
    estimator."""

    def __init__(self, n_e, e_dim, beta):
        super().__init__()
        self.e_dim = e_dim
        self.n_e = n_e
        self.beta = beta  # weight of the commitment loss
        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def forward(self, z):
        assert z.shape[-1] == self.e_dim
        z_flattened = z.contiguous().view(-1, self.e_dim)
        # Squared Euclidean distance to every codebook entry: ||z||^2 + ||e||^2 - 2 z.e
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        # Codebook loss plus beta-weighted commitment loss.
        loss = torch.mean((z_q - z.detach())**2) + self.beta * torch.mean((z_q.detach() - z)**2)
        # Straight-through estimator: copy gradients from z_q to z.
        z_q = z + (z_q - z).detach()
        min_encodings = F.one_hot(min_encoding_indices, self.n_e).type(z.dtype)
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        return loss, z_q, min_encoding_indices, perplexity

    def map2index(self, z):
        assert z.shape[-1] == self.e_dim
        z_flattened = z.contiguous().view(-1, self.e_dim)
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        min_encoding_indices = torch.argmin(d, dim=1)
        return min_encoding_indices.reshape(z.shape[0], -1)

    def get_codebook_entry(self, indices):
        index_flattened = indices.view(-1)
        z_q = self.embedding(index_flattened)
        z_q = z_q.view(indices.shape + (self.e_dim,)).contiguous()
        return z_q
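
# Illustrative usage (not part of the original module): codebook size, latent
# width, and beta are assumptions for the sketch.
def _demo_quantizer():
    vq = Quantizer(n_e=512, e_dim=128, beta=0.25)
    z = torch.randn(2, 16, 128)                  # (batch, frames, e_dim)
    loss, z_q, indices, perplexity = vq(z)
    print(z_q.shape, indices.shape)              # (2, 16, 128); indices come back flat, (32,)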
def init_weight(m):
    """Xavier-initialize conv/linear layers and zero their biases."""
    if isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class ResBlock(nn.Module):
    """Two kernel-3 convolutions with a residual connection."""

    def __init__(self, channel):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(channel, channel, 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(channel, channel, 3, 1, 1),
        )

    def forward(self, x):
        return self.model(x) + x
class VQEncoderV5(nn.Module):
    """Conv1d encoder: an input projection followed by (vae_layer - 1) conv +
    ResBlock stages. All convs are stride 1, so sequence length is preserved."""

    def __init__(self, args):
        super().__init__()
        n_down = args.vae_layer
        channels = [args.vae_length] * n_down
        input_size = args.vae_test_dim
        layers = [
            nn.Conv1d(input_size, channels[0], 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            ResBlock(channels[0]),
        ]
        for i in range(1, n_down):
            layers += [
                nn.Conv1d(channels[i - 1], channels[i], 3, 1, 1),
                nn.LeakyReLU(0.2, True),
                ResBlock(channels[i]),
            ]
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        # (B, T, C) -> (B, C, T) for Conv1d, then back.
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs


class VQEncoderV6(nn.Module):
    """Identical to VQEncoderV5; kept as a separately named class."""

    def __init__(self, args):
        super().__init__()
        n_down = args.vae_layer
        channels = [args.vae_length] * n_down
        input_size = args.vae_test_dim
        layers = [
            nn.Conv1d(input_size, channels[0], 3, 1, 1),
            nn.LeakyReLU(0.2, True),
            ResBlock(channels[0]),
        ]
        for i in range(1, n_down):
            layers += [
                nn.Conv1d(channels[i - 1], channels[i], 3, 1, 1),
                nn.LeakyReLU(0.2, True),
                ResBlock(channels[i]),
            ]
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs
class VQDecoderV5(nn.Module):
    """Conv1d decoder: ResBlocks followed by stride-1 convs that map the latent
    width back to the pose dimension; sequence length is preserved."""

    def __init__(self, args):
        super().__init__()
        n_up = args.vae_layer
        channels = [args.vae_length] * n_up + [args.vae_test_dim]
        input_size = args.vae_length
        n_resblk = 2
        if input_size == channels[0]:
            layers = []
        else:
            layers = [nn.Conv1d(input_size, channels[0], 3, 1, 1)]
        for i in range(n_resblk):
            layers += [ResBlock(channels[0])]
        for i in range(n_up):
            layers += [
                nn.Conv1d(channels[i], channels[i + 1], 3, 1, 1),
                nn.LeakyReLU(0.2, True),
            ]
        layers += [nn.Conv1d(channels[-1], channels[-1], 3, 1, 1)]
        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        inputs = inputs.permute(0, 2, 1)
        outputs = self.main(inputs).permute(0, 2, 1)
        return outputs
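
# Illustrative usage (not part of the original module): wire encoder, quantizer,
# and decoder together. The args values and codebook size are assumptions for
# the sketch; since every conv is stride 1, the frame count is unchanged.
def _demo_vq_autoencoder():
    from types import SimpleNamespace
    args = SimpleNamespace(vae_layer=4, vae_length=256, vae_test_dim=330)
    enc, dec = VQEncoderV5(args), VQDecoderV5(args)
    vq = Quantizer(n_e=256, e_dim=256, beta=0.25)
    motion = torch.randn(2, 64, 330)             # (batch, frames, pose_dim)
    _, z_q, _, _ = vq(enc(motion))               # quantized latents, (2, 64, 256)
    print(dec(z_q).shape)                        # torch.Size([2, 64, 330])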
class BasicBlock(nn.Module):
    """based on timm: https://github.com/rwightman/pytorch-image-models

    1D residual block. Note that `first_dilation` is used as the *padding* of
    the first (and downsample) convolution, not as a dilation rate.
    """

    def __init__(self, inplanes, planes, ker_size, stride=1, downsample=None, dilation=1,
                 first_dilation=None, act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm1d):
        super().__init__()
        self.conv1 = nn.Conv1d(
            inplanes, planes, kernel_size=ker_size, stride=stride, padding=first_dilation,
            dilation=dilation, bias=True)
        self.bn1 = norm_layer(planes)
        self.act1 = act_layer(inplace=True)
        self.conv2 = nn.Conv1d(
            planes, planes, kernel_size=ker_size, padding=ker_size // 2, dilation=dilation, bias=True)
        self.bn2 = norm_layer(planes)
        self.act2 = act_layer(inplace=True)
        if downsample is not None:
            # Projection shortcut so the residual addition matches in shape.
            self.downsample = nn.Sequential(
                nn.Conv1d(inplanes, planes, stride=stride, kernel_size=ker_size,
                          padding=first_dilation, dilation=dilation, bias=True),
                norm_layer(planes),
            )
        else:
            self.downsample = None

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act2(x)
        return x
class WavEncoder(nn.Module):
    """Strided 1D CNN over raw audio; the stacked strides (5 * 6 * 6 * 3 = 540)
    downsample the waveform to roughly motion frame rate."""

    def __init__(self, out_dim, audio_in=1):
        super().__init__()
        self.out_dim = out_dim
        self.feat_extractor = nn.Sequential(
            BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1600, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True),
        )

    def forward(self, wav_data):
        if wav_data.dim() == 2:
            wav_data = wav_data.unsqueeze(1)     # (B, T) -> (B, 1, T)
        else:
            wav_data = wav_data.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        out = self.feat_extractor(wav_data)
        return out.transpose(1, 2)               # (B, frames, out_dim)
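
# Illustrative usage (not part of the original module): the sample rate and
# output width are assumptions. With the strides and paddings above, 2 s of
# 16 kHz mono audio (32000 samples) comes out as 60 frames, i.e. about 30 fps.
def _demo_wav_encoder():
    enc = WavEncoder(out_dim=256)
    wav = torch.randn(2, 32000)                  # (batch, samples)
    feats = enc(wav)
    print(feats.shape)                           # torch.Size([2, 60, 256])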
class MLP(nn.Module):
    """Two-layer perceptron with a LeakyReLU in between."""

    def __init__(self, in_dim, middle_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, middle_dim)
        self.fc2 = nn.Linear(middle_dim, out_dim)
        self.act = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
class PeriodicPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that repeats every `period` steps, so
    positions t and t + period receive the same encoding."""

    def __init__(self, d_model, dropout=0.1, period=15, max_seq_len=60):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(period, d_model)
        position = torch.arange(0, period, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # Tile the one-period table so it covers max_seq_len positions.
        repeat_num = (max_seq_len // period) + 1
        pe = pe.repeat(1, repeat_num, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
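
# Illustrative usage (not part of the original module): with dropout disabled,
# frames one period apart receive identical encodings. Sizes are assumptions.
def _demo_periodic_pe():
    ppe = PeriodicPositionalEncoding(d_model=256, dropout=0.0, period=15, max_seq_len=60)
    x = torch.zeros(1, 45, 256)
    out = ppe(x)
    print(torch.allclose(out[:, 0], out[:, 15]))  # True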