|
from torch import nn |
|
|
|
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator |
|
|
|
|
|
class MelganMultiscaleDiscriminator(nn.Module):
    """Apply several ``MelganDiscriminator`` instances to successively pooled input.

    ``num_scales`` discriminators are built with identical hyperparameters.
    The first one sees the input as given; each following one sees the
    previous input after one more pass through a shared ``nn.AvgPool1d``.

    Args:
        in_channels: Forwarded to every ``MelganDiscriminator``.
        out_channels: Forwarded to every ``MelganDiscriminator``.
        num_scales: Number of discriminators (and therefore scales) to build.
        kernel_sizes: Forwarded to every ``MelganDiscriminator``.
        base_channels: Forwarded to every ``MelganDiscriminator``.
        max_channels: Forwarded to every ``MelganDiscriminator``.
        downsample_factors: Forwarded to every ``MelganDiscriminator``.
        pooling_kernel_size: ``kernel_size`` of the shared average pooling.
        pooling_stride: ``stride`` of the shared average pooling.
        pooling_padding: ``padding`` of the shared average pooling.
        groups_denominator: Forwarded to every ``MelganDiscriminator``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_scales=3,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4),
        pooling_kernel_size=4,
        pooling_stride=2,
        pooling_padding=2,
        groups_denominator=4,
    ):
        """Build the per-scale discriminators and the shared pooling layer."""
        super().__init__()

        # One independent discriminator per scale; all share the same config.
        self.discriminators = nn.ModuleList(
            [
                MelganDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    base_channels=base_channels,
                    max_channels=max_channels,
                    downsample_factors=downsample_factors,
                    groups_denominator=groups_denominator,
                )
                for _ in range(num_scales)
            ]
        )

        # count_include_pad=False keeps the zero padding out of the averages
        # computed at the edges of the signal.
        self.pooling = nn.AvgPool1d(
            kernel_size=pooling_kernel_size,
            stride=pooling_stride,
            padding=pooling_padding,
            count_include_pad=False,
        )

    def forward(self, x):
        """Evaluate every discriminator, pooling the input between scales.

        Args:
            x: Input passed to the first discriminator and to ``nn.AvgPool1d``.
                Presumably a ``(batch, channels, time)`` tensor -- confirm
                against ``MelganDiscriminator``.

        Returns:
            tuple: ``(scores, feats)``, two lists with one entry per
            discriminator, holding the first and second value each
            discriminator returned, in scale order.
        """
        scores = []
        feats = []
        last_scale = len(self.discriminators) - 1
        for scale, disc in enumerate(self.discriminators):
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
            # Pool only when another scale follows; the pooled result after
            # the last discriminator would never be consumed.
            if scale < last_scale:
                x = self.pooling(x)
        return scores, feats
|
|