File size: 7,206 Bytes
8c70653 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
# adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import functional as F
# Negative slope passed to `F.leaky_relu` after every convolution in the
# `convs` stacks of the discriminators defined in this file.
LRELU_SLOPE = 0.1
class DiscriminatorP(torch.nn.Module):
    """HiFiGAN Periodic Discriminator.

    Folds the input waveform into a 2D map with `period` columns, so that each
    column holds every Pth value of the waveform, and applies a stack of
    convolutions along the folded time axis.

    Note:
        if `period` is 2
        `waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat`

    Args:
        period (int): number of columns the waveform is folded into.
        kernel_size (int): kernel height of the convolutions in `convs`.
        stride (int): stride along the folded time axis of the first four convolutions.
        use_spectral_norm (bool): if `True` wrap the layers with spectral norm instead of weight norm.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm

        kernel = (kernel_size, 1)
        # Every layer in `convs` shares the same kernel, so they share the same
        # padding as well. The last layer used to hard-code `(2, 0)`, which only
        # matches the default `kernel_size=5`; deriving it keeps the default
        # behaviour and stays consistent for other kernel sizes.
        padding = (self._get_padding(kernel_size, 1), 0)
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv2d(1, 32, kernel, (stride, 1), padding=padding)),
                norm_f(nn.Conv2d(32, 128, kernel, (stride, 1), padding=padding)),
                norm_f(nn.Conv2d(128, 512, kernel, (stride, 1), padding=padding)),
                norm_f(nn.Conv2d(512, 1024, kernel, (stride, 1), padding=padding)),
                norm_f(nn.Conv2d(1024, 1024, kernel, 1, padding=padding)),
            ]
        )
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    @staticmethod
    def _get_padding(kernel_size, dilation=1):
        """Return the one-sided padding `(kernel_size * dilation - dilation) / 2` as an int."""
        return int((kernel_size * dilation - dilation) / 2)

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            [Tensor]: discriminator scores, flattened to one row per sample in the batch.
            [List[Tensor]]: list of features from each convolutional layer (including `conv_post`).

        Shapes:
            x: [B, 1, T]
        """
        feat = []

        # 1d to 2d
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:  # pad first so that T becomes a multiple of the period
            # NOTE(review): "reflect" padding presumably requires n_pad < T, so inputs
            # shorter than the period may fail here - confirm against the torch docs.
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = F.leaky_relu(layer(x), LRELU_SLOPE)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)

        return x, feat
class MultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Period Discriminator (MPD)

    Wrapper for `DiscriminatorP` that applies it to the same input with different periods.
    Periods are suggested to be prime numbers to reduce the overlap between each discriminator.

    Args:
        use_spectral_norm (bool): forwarded to every `DiscriminatorP`.
        periods (Iterable[int]): one `DiscriminatorP` is created per period, in the given order.
            Defaults to `(2, 3, 5, 7, 11)`.
    """

    def __init__(self, use_spectral_norm=False, periods=(2, 3, 5, 7, 11)):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period, use_spectral_norm=use_spectral_norm) for period in periods]
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            [List[Tensor]]: list of scores from each discriminator.
            [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer.

        Shapes:
            x: [B, 1, T]
        """
        scores = []
        feats = []
        # Every period discriminator receives the same, unmodified input.
        for discriminator in self.discriminators:
            score, feat = discriminator(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats
class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator.

    It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper.

    Args:
        use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        if use_spectral_norm:
            normalize = nn.utils.spectral_norm
        else:
            normalize = nn.utils.weight_norm

        # One row per layer of `convs`:
        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = (
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        )
        self.convs = nn.ModuleList(
            [
                normalize(nn.Conv1d(c_in, c_out, kernel, stride, groups=groups, padding=pad))
                for c_in, c_out, kernel, stride, groups, pad in layer_specs
            ]
        )
        self.conv_post = normalize(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: discriminator scores, flattened from dim 1 onwards.
            List[Tensor]: list of features from the convolutional layers (including `conv_post`).
        """
        activations = []
        hidden = x
        for conv in self.convs:
            hidden = conv(hidden)
            hidden = F.leaky_relu(hidden, LRELU_SLOPE)
            activations.append(hidden)

        # The final projection is collected as a feature too, without activation.
        hidden = self.conv_post(hidden)
        activations.append(hidden)

        scores = torch.flatten(hidden, 1, -1)
        return scores, activations
class MultiScaleDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Scale Discriminator.

    It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper.
    Holds three `DiscriminatorS` instances (the first with spectral norm, the others with
    weight norm) and two average-pooling layers applied between them.
    """

    def __init__(self):
        super().__init__()
        sub_discriminators = [DiscriminatorS(use_spectral_norm=True)]
        sub_discriminators.append(DiscriminatorS())
        sub_discriminators.append(DiscriminatorS())
        self.discriminators = nn.ModuleList(sub_discriminators)
        self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2) for _ in range(2)])

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of list of features from each layers of each discriminator.
        """
        all_scores, all_feats = [], []
        for index, discriminator in enumerate(self.discriminators):
            # The first discriminator sees the input as given; each following one
            # sees the previous input after one more average pooling.
            if index > 0:
                pool = self.meanpools[index - 1]
                x = pool(x)
            score, feat = discriminator(x)
            all_scores.append(score)
            all_feats.append(feat)
        return all_scores, all_feats
class HifiganDiscriminator(nn.Module):
    """HiFiGAN discriminator wrapping MPD and MSD."""

    def __init__(self):
        super().__init__()
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores, MPD outputs first, then MSD outputs.
            List[List[Tensor]]: list of list of features from each layers of each discriminator,
                in the same order as the scores.
        """
        period_scores, period_feats = self.mpd(x)
        scale_scores, scale_feats = self.msd(x)

        all_scores = period_scores + scale_scores
        all_feats = period_feats + scale_feats
        return all_scores, all_feats
|