File size: 1,479 Bytes
8c70653 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from torch import nn
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
class MelganMultiscaleDiscriminator(nn.Module):
    """Run several ``MelganDiscriminator`` instances on progressively pooled input.

    ``num_scales`` discriminators are built with identical hyperparameters.
    The first one receives the input as given; each following one receives
    the input after one more pass through a shared ``nn.AvgPool1d``.

    Args:
        in_channels: Forwarded to every ``MelganDiscriminator``.
        out_channels: Forwarded to every ``MelganDiscriminator``.
        num_scales: Number of discriminators (and therefore of scales).
        kernel_sizes: Forwarded to every ``MelganDiscriminator``.
        base_channels: Forwarded to every ``MelganDiscriminator``.
        max_channels: Forwarded to every ``MelganDiscriminator``.
        downsample_factors: Forwarded to every ``MelganDiscriminator``.
        pooling_kernel_size: Kernel size of the average pooling applied
            between scales.
        pooling_stride: Stride of the average pooling applied between scales.
        pooling_padding: Padding of the average pooling applied between
            scales; padded positions are excluded from the average
            (``count_include_pad=False``).
        groups_denominator: Forwarded to every ``MelganDiscriminator``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_scales=3,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4),
        pooling_kernel_size=4,
        pooling_stride=2,
        pooling_padding=2,
        groups_denominator=4,
    ):
        super().__init__()

        # One independent discriminator per scale; they share hyperparameters
        # but not weights (each list entry is a separate instance).
        self.discriminators = nn.ModuleList(
            [
                MelganDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    base_channels=base_channels,
                    max_channels=max_channels,
                    downsample_factors=downsample_factors,
                    groups_denominator=groups_denominator,
                )
                for _ in range(num_scales)
            ]
        )

        # A single parameter-free pooling layer is reused between all scales.
        self.pooling = nn.AvgPool1d(
            kernel_size=pooling_kernel_size,
            stride=pooling_stride,
            padding=pooling_padding,
            count_include_pad=False,
        )

    def forward(self, x):
        """Evaluate every discriminator, pooling the input between scales.

        Args:
            x: Input tensor passed to the first discriminator and to
                ``nn.AvgPool1d``.
                NOTE(review): presumably a (batch, in_channels, time)
                waveform -- confirm against the caller.

        Returns:
            A ``(scores, feats)`` tuple of two lists with one entry per
            discriminator, in scale order. Each entry is the first /
            second element of the pair returned by that discriminator.
        """
        scores = []
        feats = []
        last_scale = len(self.discriminators) - 1
        for scale, disc in enumerate(self.discriminators):
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
            # Downsample only when another discriminator will consume the
            # result; pooling after the final scale would be discarded.
            if scale < last_scale:
                x = self.pooling(x)
        return scores, feats
|