# FastDiff model definition.
import logging

import torch
import torch.nn as nn

from modules.FastDiff.module.modules import DiffusionDBlock, TimeAware_LVCBlock
from modules.FastDiff.module.util import calc_diffusion_step_embedding
def swish(x):
    """Swish/SiLU activation: ``x * sigmoid(x)``.

    Uses the fused native ``torch.nn.functional.silu`` instead of the
    hand-rolled ``x * torch.sigmoid(x)`` — same values, one kernel.

    Args:
        x (Tensor): Input tensor of any shape.

    Returns:
        Tensor: Element-wise ``x * sigmoid(x)``, same shape as ``x``.
    """
    return torch.nn.functional.silu(x)
class FastDiff(nn.Module):
    """FastDiff module.

    Diffusion-based vocoder: refines a noisy waveform conditioned on local
    acoustic features (e.g. mel-spectrogram) and a diffusion step embedding,
    using time-aware location-variable convolution (LVC) blocks.
    """

    def __init__(self,
                 audio_channels=1,
                 inner_channels=32,
                 cond_channels=80,
                 upsample_ratios=(8, 8, 4),
                 lvc_layers_each_block=4,
                 lvc_kernel_size=3,
                 kpnet_hidden_channels=64,
                 kpnet_conv_size=3,
                 dropout=0.0,
                 diffusion_step_embed_dim_in=128,
                 diffusion_step_embed_dim_mid=512,
                 diffusion_step_embed_dim_out=512,
                 use_weight_norm=True):
        """Initialize FastDiff.

        Args:
            audio_channels (int): Channels of the input/output waveform.
            inner_channels (int): Hidden channel width of the conv stack.
            cond_channels (int): Channels of the conditioning features
                (presumably mel bins — 80 by default).
            upsample_ratios (sequence of int): Per-LVC-block upsampling
                factors; their running product gives each block's
                ``cond_hop_length``. (Tuple default avoids the
                mutable-default-argument pitfall; lists still work.)
            lvc_layers_each_block (int): Conv layers inside each LVC block.
            lvc_kernel_size (int): Kernel size of the LVC convolutions.
            kpnet_hidden_channels (int): Hidden channels of the kernel
                predictor network.
            kpnet_conv_size (int): Conv size of the kernel predictor network.
            dropout (float): Dropout rate of the kernel predictor network.
            diffusion_step_embed_dim_in (int): Dim of the raw diffusion step
                embedding fed to the MLP.
            diffusion_step_embed_dim_mid (int): Hidden dim of the
                step-embedding MLP.
            diffusion_step_embed_dim_out (int): Output dim of the
                step-embedding MLP (consumed by the LVC blocks).
            use_weight_norm (bool): If True, apply weight normalization to
                all Conv1d/Conv2d submodules.
        """
        super().__init__()

        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
        self.audio_channels = audio_channels
        self.cond_channels = cond_channels
        self.lvc_block_nums = len(upsample_ratios)

        # Initial audio conv. The input-channel count was hard-coded to 1;
        # use `audio_channels` so multi-channel audio is supported (the
        # default of 1 preserves the original behavior).
        self.first_audio_conv = nn.Conv1d(audio_channels, inner_channels,
                                          kernel_size=7, padding=(7 - 1) // 2,
                                          dilation=1, bias=True)

        # define residual blocks
        self.lvc_blocks = nn.ModuleList()
        self.downsample = nn.ModuleList()

        # the layer-specific fc for noise scale embedding
        # NOTE(review): `fc_t` is never populated or referenced anywhere in
        # this class; kept (empty, parameter-free) for attribute compatibility.
        self.fc_t = nn.ModuleList()
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)

        # Each LVC block upsamples by its ratio; `cond_hop_length` accumulates
        # the product so block n knows its total hop w.r.t. the conditioning.
        cond_hop_length = 1
        for n in range(self.lvc_block_nums):
            cond_hop_length = cond_hop_length * upsample_ratios[n]
            lvcb = TimeAware_LVCBlock(
                in_channels=inner_channels,
                cond_channels=cond_channels,
                upsample_ratio=upsample_ratios[n],
                conv_layers=lvc_layers_each_block,
                conv_kernel_size=lvc_kernel_size,
                cond_hop_length=cond_hop_length,
                kpnet_hidden_channels=kpnet_hidden_channels,
                kpnet_conv_size=kpnet_conv_size,
                kpnet_dropout=dropout,
                noise_scale_embed_dim_out=diffusion_step_embed_dim_out,
            )
            self.lvc_blocks.append(lvcb)
            # The downsample path mirrors the upsample ratios in REVERSE order
            # so that each skip connection matches its LVC block's resolution.
            self.downsample.append(
                DiffusionDBlock(inner_channels, inner_channels,
                                upsample_ratios[self.lvc_block_nums - n - 1]))

        # define output layers
        self.final_conv = nn.Sequential(
            nn.Conv1d(inner_channels, audio_channels, kernel_size=7,
                      padding=(7 - 1) // 2, dilation=1, bias=True))

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, data):
        """Calculate forward propagation.

        Args:
            data (tuple): ``(audio, c, diffusion_steps)`` where
                audio (Tensor): Input noise signal (B, audio_channels, T).
                c (Tensor): Local conditioning features (B, cond_channels, T').
                diffusion_steps (Tensor): Diffusion step indices for the batch.

        Returns:
            Tensor: Output tensor (B, audio_channels, T).
        """
        audio, c, diffusion_steps = data

        # embed diffusion step t through a two-layer swish MLP
        diffusion_step_embed = calc_diffusion_step_embedding(
            diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))

        audio = self.first_audio_conv(audio)

        # Downsample path: stash each resolution's features (pre-downsample)
        # as skip inputs for the matching LVC block on the way back up.
        downsample = []
        for down_layer in self.downsample:
            downsample.append(audio)
            audio = down_layer(audio)

        # Upsample path: consume the stashed features in reverse (coarsest
        # first), conditioning each LVC block on c and the step embedding.
        x = audio
        for n, audio_down in enumerate(reversed(downsample)):
            x = self.lvc_blocks[n]((x, audio_down, c, diffusion_step_embed))

        # apply final layers
        x = self.final_conv(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
                # Log only after removal succeeds (the original logged before
                # attempting, reporting removal even for modules without it).
                logging.debug(f"Weight norm is removed from {m}.")
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)