File size: 3,306 Bytes
8c70653 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch
from torch import nn
from torch.nn.utils import weight_norm
from TTS.utils.io import load_fsspec
from TTS.vocoder.layers.melgan import ResidualStack
class MelganGenerator(nn.Module):
    """MelGAN-style generator built as a single ``nn.Sequential`` stack.

    Layer layout:

    1. Reflection padding + weight-normalised ``Conv1d`` projecting
       ``in_channels`` -> ``base_channels`` (length preserved).
    2. For each factor ``u`` in ``upsample_factors``: ``LeakyReLU(0.2)``, a
       weight-normalised ``ConvTranspose1d`` with stride ``u`` that halves the
       channel count, then a ``ResidualStack``.
    3. ``LeakyReLU(0.2)``, reflection padding + weight-normalised ``Conv1d``
       to ``out_channels``, and ``Tanh`` (outputs lie in ``[-1, 1]``).

    Each transposed convolution maps a length ``L`` to exactly ``L * u``.
    Assuming ``ResidualStack`` preserves length (TODO confirm in
    ``TTS.vocoder.layers.melgan``), the total time upsampling is
    ``prod(upsample_factors)``.

    Args:
        in_channels: channels of the conditioning input (default 80;
            presumably mel bands -- confirm against the feature config).
        out_channels: channels of the generated output.
        proj_kernel: kernel size of the first and last convolutions. It must be
            odd so that symmetric padding of ``(k - 1) // 2`` keeps the length.
        base_channels: channels after the first projection; halved by every
            upsampling stage.
        upsample_factors: per-stage time-upsampling strides. May be empty, in
            which case no upsampling stages are built.
        res_kernel: kernel size passed to each ``ResidualStack``.
        num_res_blocks: number of blocks passed to each ``ResidualStack``.
    """

    def __init__(
        self,
        in_channels=80,
        out_channels=1,
        proj_kernel=7,
        base_channels=512,
        upsample_factors=(8, 8, 2, 2),
        res_kernel=3,
        num_res_blocks=3,
    ):
        super().__init__()
        # NOTE(review): ``assert`` is stripped under ``python -O``. It is kept so
        # the exception type raised for bad input (AssertionError) is unchanged.
        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."
        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        # Frames of replicate padding that ``inference`` adds on each side.
        self.inference_padding = 2

        # Initial projection: in_channels -> base_channels.
        layers = [
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)),
        ]

        # Channels feeding the final projection. Bound before the loop so that
        # an empty ``upsample_factors`` still builds (previously a NameError).
        layer_out_channels = base_channels

        # Upsampling layers and residual stacks.
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2 ** (idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            # With kernel 2u and stride u, the settings padding = u // 2 + u % 2
            # and output_padding = u % 2 give an output length of exactly L * u,
            # for both even and odd u.
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(
                        layer_in_channels,
                        layer_out_channels,
                        layer_filter_size,
                        stride=layer_stride,
                        padding=layer_padding,
                        output_padding=layer_output_padding,
                        bias=True,
                    )
                ),
                ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel),
            ]

        # Final projection to out_channels, squashed into [-1, 1].
        layers += [
            nn.LeakyReLU(act_slope),
            nn.ReflectionPad1d(base_padding),
            weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, c):
        """Run the generator on ``c`` of shape ``(batch, in_channels, T)``.

        Returns a tensor of shape ``(batch, out_channels, T')`` with values in
        ``[-1, 1]`` because of the final ``Tanh``.
        """
        return self.layers(c)

    def inference(self, c):
        """Generate from ``c`` with extra edge padding.

        ``c`` is first moved to the device of the model's weights. Its last
        (time) axis is then replicate-padded by ``self.inference_padding``
        frames on each side before the layers run. The output is therefore
        longer than ``forward(c)`` would give for the unpadded input.
        """
        # layers[0] is the ReflectionPad1d; layers[1] is the first Conv1d.
        # ``weight`` is available on it both with and without weight norm.
        c = c.to(self.layers[1].weight.device)
        c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate")
        return self.layers(c)

    def remove_weight_norm(self):
        """Strip weight normalisation from every layer that has state.

        Layers with an empty ``state_dict`` (padding, activations, Tanh) are
        skipped. Layers wrapped directly with ``weight_norm`` are handled by
        ``nn.utils.remove_weight_norm``. That call raises ``ValueError`` for
        other layers (e.g. ``ResidualStack``), which are then asked to remove
        their own weight norm through their ``remove_weight_norm`` method.
        """
        for layer in self.layers:
            if not layer.state_dict():
                continue
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                layer.remove_weight_norm()

    def load_checkpoint(
        self, config, checkpoint_path, eval=False, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load weights from ``checkpoint_path`` onto the CPU via ``load_fsspec``.

        Args:
            config: unused; kept for interface compatibility.
            checkpoint_path: location handed to ``load_fsspec``.
            eval: if True, switch to eval mode and remove weight norm.
            cache: forwarded to ``load_fsspec``.
        """
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        # The state is loaded into the weight-normed modules before weight norm
        # is stripped. This presumably means checkpoints are saved with weight
        # norm in place -- confirm against the training code.
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training
            self.remove_weight_norm()
|