Plachta commited on
Commit
a4d0945
·
verified ·
1 Parent(s): 48dfa3d

Upload 69 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. alias_free_torch/__init__.py +5 -0
  2. alias_free_torch/__pycache__/__init__.cpython-310.pyc +0 -0
  3. alias_free_torch/__pycache__/act.cpython-310.pyc +0 -0
  4. alias_free_torch/__pycache__/filter.cpython-310.pyc +0 -0
  5. alias_free_torch/__pycache__/resample.cpython-310.pyc +0 -0
  6. alias_free_torch/act.py +29 -0
  7. alias_free_torch/filter.py +96 -0
  8. alias_free_torch/resample.py +57 -0
  9. dac/__init__.py +16 -0
  10. dac/__main__.py +36 -0
  11. dac/__pycache__/__init__.cpython-310.pyc +0 -0
  12. dac/model/__init__.py +4 -0
  13. dac/model/__pycache__/__init__.cpython-310.pyc +0 -0
  14. dac/model/__pycache__/base.cpython-310.pyc +0 -0
  15. dac/model/__pycache__/dac.cpython-310.pyc +0 -0
  16. dac/model/__pycache__/discriminator.cpython-310.pyc +0 -0
  17. dac/model/__pycache__/encodec.cpython-310.pyc +0 -0
  18. dac/model/base.py +294 -0
  19. dac/model/dac.py +389 -0
  20. dac/model/discriminator.py +228 -0
  21. dac/model/encodec.py +288 -0
  22. dac/nn/__init__.py +3 -0
  23. dac/nn/__pycache__/__init__.cpython-310.pyc +0 -0
  24. dac/nn/__pycache__/layers.cpython-310.pyc +0 -0
  25. dac/nn/__pycache__/loss.cpython-310.pyc +0 -0
  26. dac/nn/__pycache__/quantize.cpython-310.pyc +0 -0
  27. dac/nn/layers.py +33 -0
  28. dac/nn/loss.py +368 -0
  29. dac/nn/quantize.py +262 -0
  30. dac/utils/__init__.py +123 -0
  31. dac/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  32. dac/utils/decode.py +95 -0
  33. dac/utils/encode.py +94 -0
  34. hf_utils.py +11 -0
  35. modules/__pycache__/attentions.cpython-310.pyc +0 -0
  36. modules/__pycache__/commons.cpython-310.pyc +0 -0
  37. modules/__pycache__/layers.cpython-310.pyc +0 -0
  38. modules/__pycache__/mamba.cpython-310.pyc +0 -0
  39. modules/__pycache__/quantize.cpython-310.pyc +0 -0
  40. modules/__pycache__/redecoder.cpython-310.pyc +0 -0
  41. modules/__pycache__/style_encoder.cpython-310.pyc +0 -0
  42. modules/__pycache__/wavenet.cpython-310.pyc +0 -0
  43. modules/attentions.py +324 -0
  44. modules/beta_vae.py +101 -0
  45. modules/commons.py +479 -0
  46. modules/layers.py +354 -0
  47. modules/quantize.py +613 -0
  48. modules/redecoder.py +63 -0
  49. modules/style_encoder.py +91 -0
  50. modules/wavenet.py +174 -0
alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
alias_free_torch/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (190 Bytes). View file
 
alias_free_torch/__pycache__/act.cpython-310.pyc ADDED
Binary file (1.02 kB). View file
 
alias_free_torch/__pycache__/filter.cpython-310.pyc ADDED
Binary file (2.6 kB). View file
 
alias_free_torch/__pycache__/resample.cpython-310.pyc ADDED
Binary file (1.88 kB). View file
 
alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
alias_free_torch/filter.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ if "sinc" in dir(torch):
9
+ sinc = torch.sinc
10
+ else:
11
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
12
+ # https://adefossez.github.io/julius/julius/core.html
13
+ def sinc(x: torch.Tensor):
14
+ """
15
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
16
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
17
+ """
18
+ return torch.where(
19
+ x == 0,
20
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
21
+ torch.sin(math.pi * x) / math.pi / x,
22
+ )
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ def kaiser_sinc_filter1d(
28
+ cutoff, half_width, kernel_size
29
+ ): # return filter [1,1,kernel_size]
30
+ even = kernel_size % 2 == 0
31
+ half_size = kernel_size // 2
32
+
33
+ # For kaiser window
34
+ delta_f = 4 * half_width
35
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36
+ if A > 50.0:
37
+ beta = 0.1102 * (A - 8.7)
38
+ elif A >= 21.0:
39
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40
+ else:
41
+ beta = 0.0
42
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43
+
44
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45
+ if even:
46
+ time = torch.arange(-half_size, half_size) + 0.5
47
+ else:
48
+ time = torch.arange(kernel_size) - half_size
49
+ if cutoff == 0:
50
+ filter_ = torch.zeros_like(time)
51
+ else:
52
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
54
+ # of the constant component in the input signal.
55
+ filter_ /= filter_.sum()
56
+ filter = filter_.view(1, 1, kernel_size)
57
+
58
+ return filter
59
+
60
+
61
+ class LowPassFilter1d(nn.Module):
62
+ def __init__(
63
+ self,
64
+ cutoff=0.5,
65
+ half_width=0.6,
66
+ stride: int = 1,
67
+ padding: bool = True,
68
+ padding_mode: str = "replicate",
69
+ kernel_size: int = 12,
70
+ ):
71
+ # kernel_size should be even number for stylegan3 setup,
72
+ # in this implementation, odd number is also possible.
73
+ super().__init__()
74
+ if cutoff < -0.0:
75
+ raise ValueError("Minimum cutoff must be larger than zero.")
76
+ if cutoff > 0.5:
77
+ raise ValueError("A cutoff above 0.5 does not make sense.")
78
+ self.kernel_size = kernel_size
79
+ self.even = kernel_size % 2 == 0
80
+ self.pad_left = kernel_size // 2 - int(self.even)
81
+ self.pad_right = kernel_size // 2
82
+ self.stride = stride
83
+ self.padding = padding
84
+ self.padding_mode = padding_mode
85
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86
+ self.register_buffer("filter", filter)
87
+
88
+ # input [B, C, T]
89
+ def forward(self, x):
90
+ _, C, _ = x.shape
91
+
92
+ if self.padding:
93
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95
+
96
+ return out
alias_free_torch/resample.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from .filter import LowPassFilter1d
6
+ from .filter import kaiser_sinc_filter1d
7
+
8
+
9
+ class UpSample1d(nn.Module):
10
+ def __init__(self, ratio=2, kernel_size=None):
11
+ super().__init__()
12
+ self.ratio = ratio
13
+ self.kernel_size = (
14
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ )
16
+ self.stride = ratio
17
+ self.pad = self.kernel_size // ratio - 1
18
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
19
+ self.pad_right = (
20
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
21
+ )
22
+ filter = kaiser_sinc_filter1d(
23
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
24
+ )
25
+ self.register_buffer("filter", filter)
26
+
27
+ # x: [B, C, T]
28
+ def forward(self, x):
29
+ _, C, _ = x.shape
30
+
31
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
32
+ x = self.ratio * F.conv_transpose1d(
33
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
34
+ )
35
+ x = x[..., self.pad_left : -self.pad_right]
36
+
37
+ return x
38
+
39
+
40
+ class DownSample1d(nn.Module):
41
+ def __init__(self, ratio=2, kernel_size=None):
42
+ super().__init__()
43
+ self.ratio = ratio
44
+ self.kernel_size = (
45
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
46
+ )
47
+ self.lowpass = LowPassFilter1d(
48
+ cutoff=0.5 / ratio,
49
+ half_width=0.6 / ratio,
50
+ stride=ratio,
51
+ kernel_size=self.kernel_size,
52
+ )
53
+
54
+ def forward(self, x):
55
+ xx = self.lowpass(x)
56
+
57
+ return xx
dac/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "1.0.0"
2
+
3
+ # preserved here for legacy reasons
4
+ __model_version__ = "latest"
5
+
6
+ import audiotools
7
+
8
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
10
+
11
+
12
+ from . import nn
13
+ from . import model
14
+ from . import utils
15
+ from .model import DAC
16
+ from .model import DACFile
dac/__main__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from dac.utils import download
6
+ from dac.utils.decode import decode
7
+ from dac.utils.encode import encode
8
+
9
+ STAGES = ["encode", "decode", "download"]
10
+
11
+
12
+ def run(stage: str):
13
+ """Run stages.
14
+
15
+ Parameters
16
+ ----------
17
+ stage : str
18
+ Stage to run
19
+ """
20
+ if stage not in STAGES:
21
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
+ stage_fn = globals()[stage]
23
+
24
+ if stage == "download":
25
+ stage_fn()
26
+ return
27
+
28
+ stage_fn()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ group = sys.argv.pop(1)
33
+ args = argbind.parse_args(group=group)
34
+
35
+ with argbind.scope(args):
36
+ run(group)
dac/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (435 Bytes). View file
 
dac/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
dac/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (276 Bytes). View file
 
dac/model/__pycache__/base.cpython-310.pyc ADDED
Binary file (7.18 kB). View file
 
dac/model/__pycache__/dac.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
dac/model/__pycache__/discriminator.cpython-310.pyc ADDED
Binary file (7.97 kB). View file
 
dac/model/__pycache__/encodec.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
dac/model/base.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(
52
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
+ )
54
+ return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
+ class CodecMixin:
58
+ @property
59
+ def padding(self):
60
+ if not hasattr(self, "_padding"):
61
+ self._padding = True
62
+ return self._padding
63
+
64
+ @padding.setter
65
+ def padding(self, value):
66
+ assert isinstance(value, bool)
67
+
68
+ layers = [
69
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
+ ]
71
+
72
+ for layer in layers:
73
+ if value:
74
+ if hasattr(layer, "original_padding"):
75
+ layer.padding = layer.original_padding
76
+ else:
77
+ layer.original_padding = layer.padding
78
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
+
80
+ self._padding = value
81
+
82
+ def get_delay(self):
83
+ # Any number works here, delay is invariant to input length
84
+ l_out = self.get_output_length(0)
85
+ L = l_out
86
+
87
+ layers = []
88
+ for layer in self.modules():
89
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
+ layers.append(layer)
91
+
92
+ for layer in reversed(layers):
93
+ d = layer.dilation[0]
94
+ k = layer.kernel_size[0]
95
+ s = layer.stride[0]
96
+
97
+ if isinstance(layer, nn.ConvTranspose1d):
98
+ L = ((L - d * (k - 1) - 1) / s) + 1
99
+ elif isinstance(layer, nn.Conv1d):
100
+ L = (L - 1) * s + d * (k - 1) + 1
101
+
102
+ L = math.ceil(L)
103
+
104
+ l_in = L
105
+
106
+ return (l_in - l_out) // 2
107
+
108
+ def get_output_length(self, input_length):
109
+ L = input_length
110
+ # Calculate output length
111
+ for layer in self.modules():
112
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
+ d = layer.dilation[0]
114
+ k = layer.kernel_size[0]
115
+ s = layer.stride[0]
116
+
117
+ if isinstance(layer, nn.Conv1d):
118
+ L = ((L - d * (k - 1) - 1) / s) + 1
119
+ elif isinstance(layer, nn.ConvTranspose1d):
120
+ L = (L - 1) * s + d * (k - 1) + 1
121
+
122
+ L = math.floor(L)
123
+ return L
124
+
125
+ @torch.no_grad()
126
+ def compress(
127
+ self,
128
+ audio_path_or_signal: Union[str, Path, AudioSignal],
129
+ win_duration: float = 1.0,
130
+ verbose: bool = False,
131
+ normalize_db: float = -16,
132
+ n_quantizers: int = None,
133
+ ) -> DACFile:
134
+ """Processes an audio signal from a file or AudioSignal object into
135
+ discrete codes. This function processes the signal in short windows,
136
+ using constant GPU memory.
137
+
138
+ Parameters
139
+ ----------
140
+ audio_path_or_signal : Union[str, Path, AudioSignal]
141
+ audio signal to reconstruct
142
+ win_duration : float, optional
143
+ window duration in seconds, by default 5.0
144
+ verbose : bool, optional
145
+ by default False
146
+ normalize_db : float, optional
147
+ normalize db, by default -16
148
+
149
+ Returns
150
+ -------
151
+ DACFile
152
+ Object containing compressed codes and metadata
153
+ required for decompression
154
+ """
155
+ audio_signal = audio_path_or_signal
156
+ if isinstance(audio_signal, (str, Path)):
157
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
+
159
+ self.eval()
160
+ original_padding = self.padding
161
+ original_device = audio_signal.device
162
+
163
+ audio_signal = audio_signal.clone()
164
+ original_sr = audio_signal.sample_rate
165
+
166
+ resample_fn = audio_signal.resample
167
+ loudness_fn = audio_signal.loudness
168
+
169
+ # If audio is > 10 minutes long, use the ffmpeg versions
170
+ if audio_signal.signal_duration >= 10 * 60 * 60:
171
+ resample_fn = audio_signal.ffmpeg_resample
172
+ loudness_fn = audio_signal.ffmpeg_loudness
173
+
174
+ original_length = audio_signal.signal_length
175
+ resample_fn(self.sample_rate)
176
+ input_db = loudness_fn()
177
+
178
+ if normalize_db is not None:
179
+ audio_signal.normalize(normalize_db)
180
+ audio_signal.ensure_max_of_audio()
181
+
182
+ nb, nac, nt = audio_signal.audio_data.shape
183
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
184
+ win_duration = (
185
+ audio_signal.signal_duration if win_duration is None else win_duration
186
+ )
187
+
188
+ if audio_signal.signal_duration <= win_duration:
189
+ # Unchunked compression (used if signal length < win duration)
190
+ self.padding = True
191
+ n_samples = nt
192
+ hop = nt
193
+ else:
194
+ # Chunked inference
195
+ self.padding = False
196
+ # Zero-pad signal on either side by the delay
197
+ audio_signal.zero_pad(self.delay, self.delay)
198
+ n_samples = int(win_duration * self.sample_rate)
199
+ # Round n_samples to nearest hop length multiple
200
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
201
+ hop = self.get_output_length(n_samples)
202
+
203
+ codes = []
204
+ range_fn = range if not verbose else tqdm.trange
205
+
206
+ for i in range_fn(0, nt, hop):
207
+ x = audio_signal[..., i : i + n_samples]
208
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
209
+
210
+ audio_data = x.audio_data.to(self.device)
211
+ audio_data = self.preprocess(audio_data, self.sample_rate)
212
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
213
+ codes.append(c.to(original_device))
214
+ chunk_length = c.shape[-1]
215
+
216
+ codes = torch.cat(codes, dim=-1)
217
+
218
+ dac_file = DACFile(
219
+ codes=codes,
220
+ chunk_length=chunk_length,
221
+ original_length=original_length,
222
+ input_db=input_db,
223
+ channels=nac,
224
+ sample_rate=original_sr,
225
+ padding=self.padding,
226
+ dac_version=SUPPORTED_VERSIONS[-1],
227
+ )
228
+
229
+ if n_quantizers is not None:
230
+ codes = codes[:, :n_quantizers, :]
231
+
232
+ self.padding = original_padding
233
+ return dac_file
234
+
235
+ @torch.no_grad()
236
+ def decompress(
237
+ self,
238
+ obj: Union[str, Path, DACFile],
239
+ verbose: bool = False,
240
+ ) -> AudioSignal:
241
+ """Reconstruct audio from a given .dac file
242
+
243
+ Parameters
244
+ ----------
245
+ obj : Union[str, Path, DACFile]
246
+ .dac file location or corresponding DACFile object.
247
+ verbose : bool, optional
248
+ Prints progress if True, by default False
249
+
250
+ Returns
251
+ -------
252
+ AudioSignal
253
+ Object with the reconstructed audio
254
+ """
255
+ self.eval()
256
+ if isinstance(obj, (str, Path)):
257
+ obj = DACFile.load(obj)
258
+
259
+ original_padding = self.padding
260
+ self.padding = obj.padding
261
+
262
+ range_fn = range if not verbose else tqdm.trange
263
+ codes = obj.codes
264
+ original_device = codes.device
265
+ chunk_length = obj.chunk_length
266
+ recons = []
267
+
268
+ for i in range_fn(0, codes.shape[-1], chunk_length):
269
+ c = codes[..., i : i + chunk_length].to(self.device)
270
+ z = self.quantizer.from_codes(c)[0]
271
+ r = self.decode(z)
272
+ recons.append(r.to(original_device))
273
+
274
+ recons = torch.cat(recons, dim=-1)
275
+ recons = AudioSignal(recons, self.sample_rate)
276
+
277
+ resample_fn = recons.resample
278
+ loudness_fn = recons.loudness
279
+
280
+ # If audio is > 10 minutes long, use the ffmpeg versions
281
+ if recons.signal_duration >= 10 * 60 * 60:
282
+ resample_fn = recons.ffmpeg_resample
283
+ loudness_fn = recons.ffmpeg_loudness
284
+
285
+ recons.normalize(obj.input_db)
286
+ resample_fn(obj.sample_rate)
287
+ recons = recons[..., : obj.original_length]
288
+ loudness_fn()
289
+ recons.audio_data = recons.audio_data.reshape(
290
+ -1, obj.channels, obj.original_length
291
+ )
292
+
293
+ self.padding = original_padding
294
+ return recons
dac/model/dac.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+ from .encodec import SConv1d, SConvTranspose1d, SLSTM
17
+
18
+
19
+ def init_weights(m):
20
+ if isinstance(m, nn.Conv1d):
21
+ nn.init.trunc_normal_(m.weight, std=0.02)
22
+ nn.init.constant_(m.bias, 0)
23
+
24
+
25
+ class ResidualUnit(nn.Module):
26
+ def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
27
+ super().__init__()
28
+ conv1d_type = SConv1d# if causal else WNConv1d
29
+ pad = ((7 - 1) * dilation) // 2
30
+ self.block = nn.Sequential(
31
+ Snake1d(dim),
32
+ conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'),
33
+ Snake1d(dim),
34
+ conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'),
35
+ )
36
+
37
+ def forward(self, x):
38
+ y = self.block(x)
39
+ pad = (x.shape[-1] - y.shape[-1]) // 2
40
+ if pad > 0:
41
+ x = x[..., pad:-pad]
42
+ return x + y
43
+
44
+
45
+ class EncoderBlock(nn.Module):
46
+ def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
47
+ super().__init__()
48
+ conv1d_type = SConv1d# if causal else WNConv1d
49
+ self.block = nn.Sequential(
50
+ ResidualUnit(dim // 2, dilation=1, causal=causal),
51
+ ResidualUnit(dim // 2, dilation=3, causal=causal),
52
+ ResidualUnit(dim // 2, dilation=9, causal=causal),
53
+ Snake1d(dim // 2),
54
+ conv1d_type(
55
+ dim // 2,
56
+ dim,
57
+ kernel_size=2 * stride,
58
+ stride=stride,
59
+ padding=math.ceil(stride / 2),
60
+ causal=causal,
61
+ norm='weight_norm',
62
+ ),
63
+ )
64
+
65
+ def forward(self, x):
66
+ return self.block(x)
67
+
68
+
69
+ class Encoder(nn.Module):
70
+ def __init__(
71
+ self,
72
+ d_model: int = 64,
73
+ strides: list = [2, 4, 8, 8],
74
+ d_latent: int = 64,
75
+ causal: bool = False,
76
+ lstm: int = 2,
77
+ ):
78
+ super().__init__()
79
+ conv1d_type = SConv1d# if causal else WNConv1d
80
+ # Create first convolution
81
+ self.block = [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
82
+
83
+ # Create EncoderBlocks that double channels as they downsample by `stride`
84
+ for stride in strides:
85
+ d_model *= 2
86
+ self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]
87
+
88
+ # Add LSTM if needed
89
+ self.use_lstm = lstm
90
+ if lstm:
91
+ self.block += [SLSTM(d_model, lstm)]
92
+
93
+ # Create last convolution
94
+ self.block += [
95
+ Snake1d(d_model),
96
+ conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'),
97
+ ]
98
+
99
+ # Wrap black into nn.Sequential
100
+ self.block = nn.Sequential(*self.block)
101
+ self.enc_dim = d_model
102
+
103
+ def forward(self, x):
104
+ return self.block(x)
105
+
106
+
107
+ class DecoderBlock(nn.Module):
108
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False):
109
+ super().__init__()
110
+ conv1d_type = SConvTranspose1d #if causal else WNConvTranspose1d
111
+ self.block = nn.Sequential(
112
+ Snake1d(input_dim),
113
+ conv1d_type(
114
+ input_dim,
115
+ output_dim,
116
+ kernel_size=2 * stride,
117
+ stride=stride,
118
+ padding=math.ceil(stride / 2),
119
+ causal=causal,
120
+ norm='weight_norm'
121
+ ),
122
+ ResidualUnit(output_dim, dilation=1, causal=causal),
123
+ ResidualUnit(output_dim, dilation=3, causal=causal),
124
+ ResidualUnit(output_dim, dilation=9, causal=causal),
125
+ )
126
+
127
+ def forward(self, x):
128
+ return self.block(x)
129
+
130
+
131
+ class Decoder(nn.Module):
132
+ def __init__(
133
+ self,
134
+ input_channel,
135
+ channels,
136
+ rates,
137
+ d_out: int = 1,
138
+ causal: bool = False,
139
+ lstm: int = 2,
140
+ ):
141
+ super().__init__()
142
+ conv1d_type = SConv1d# if causal else WNConv1d
143
+ # Add first conv layer
144
+ layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')]
145
+
146
+ if lstm:
147
+ layers += [SLSTM(channels, num_layers=lstm)]
148
+
149
+ # Add upsampling + MRF blocks
150
+ for i, stride in enumerate(rates):
151
+ input_dim = channels // 2**i
152
+ output_dim = channels // 2 ** (i + 1)
153
+ layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]
154
+
155
+ # Add final conv layer
156
+ layers += [
157
+ Snake1d(output_dim),
158
+ conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'),
159
+ nn.Tanh(),
160
+ ]
161
+
162
+ self.model = nn.Sequential(*layers)
163
+
164
+ def forward(self, x):
165
+ return self.model(x)
166
+
167
+
168
+ class DAC(BaseModel, CodecMixin):
169
+ def __init__(
170
+ self,
171
+ encoder_dim: int = 64,
172
+ encoder_rates: List[int] = [2, 4, 8, 8],
173
+ latent_dim: int = None,
174
+ decoder_dim: int = 1536,
175
+ decoder_rates: List[int] = [8, 8, 4, 2],
176
+ n_codebooks: int = 9,
177
+ codebook_size: int = 1024,
178
+ codebook_dim: Union[int, list] = 8,
179
+ quantizer_dropout: bool = False,
180
+ sample_rate: int = 44100,
181
+ lstm: int = 2,
182
+ causal: bool = False,
183
+ ):
184
+ super().__init__()
185
+
186
+ self.encoder_dim = encoder_dim
187
+ self.encoder_rates = encoder_rates
188
+ self.decoder_dim = decoder_dim
189
+ self.decoder_rates = decoder_rates
190
+ self.sample_rate = sample_rate
191
+
192
+ if latent_dim is None:
193
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
194
+
195
+ self.latent_dim = latent_dim
196
+
197
+ self.hop_length = np.prod(encoder_rates)
198
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm)
199
+
200
+ self.n_codebooks = n_codebooks
201
+ self.codebook_size = codebook_size
202
+ self.codebook_dim = codebook_dim
203
+ self.quantizer = ResidualVectorQuantize(
204
+ input_dim=latent_dim,
205
+ n_codebooks=n_codebooks,
206
+ codebook_size=codebook_size,
207
+ codebook_dim=codebook_dim,
208
+ quantizer_dropout=quantizer_dropout,
209
+ )
210
+
211
+ self.decoder = Decoder(
212
+ latent_dim,
213
+ decoder_dim,
214
+ decoder_rates,
215
+ lstm=lstm,
216
+ causal=causal,
217
+ )
218
+ self.sample_rate = sample_rate
219
+ self.apply(init_weights)
220
+
221
+ self.delay = self.get_delay()
222
+
223
+ def preprocess(self, audio_data, sample_rate):
224
+ if sample_rate is None:
225
+ sample_rate = self.sample_rate
226
+ assert sample_rate == self.sample_rate
227
+
228
+ length = audio_data.shape[-1]
229
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
230
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
231
+
232
+ return audio_data
233
+
234
+ def encode(
235
+ self,
236
+ audio_data: torch.Tensor,
237
+ n_quantizers: int = None,
238
+ ):
239
+ """Encode given audio data and return quantized latent codes
240
+
241
+ Parameters
242
+ ----------
243
+ audio_data : Tensor[B x 1 x T]
244
+ Audio data to encode
245
+ n_quantizers : int, optional
246
+ Number of quantizers to use, by default None
247
+ If None, all quantizers are used.
248
+
249
+ Returns
250
+ -------
251
+ dict
252
+ A dictionary with the following keys:
253
+ "z" : Tensor[B x D x T]
254
+ Quantized continuous representation of input
255
+ "codes" : Tensor[B x N x T]
256
+ Codebook indices for each codebook
257
+ (quantized discrete representation of input)
258
+ "latents" : Tensor[B x N*D x T]
259
+ Projected latents (continuous representation of input before quantization)
260
+ "vq/commitment_loss" : Tensor[1]
261
+ Commitment loss to train encoder to predict vectors closer to codebook
262
+ entries
263
+ "vq/codebook_loss" : Tensor[1]
264
+ Codebook loss to update the codebook
265
+ "length" : int
266
+ Number of samples in input audio
267
+ """
268
+ z = self.encoder(audio_data)
269
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
270
+ z, n_quantizers
271
+ )
272
+ return z, codes, latents, commitment_loss, codebook_loss
273
+
274
+ def decode(self, z: torch.Tensor):
275
+ """Decode given latent codes and return audio data
276
+
277
+ Parameters
278
+ ----------
279
+ z : Tensor[B x D x T]
280
+ Quantized continuous representation of input
281
+ length : int, optional
282
+ Number of samples in output audio, by default None
283
+
284
+ Returns
285
+ -------
286
+ dict
287
+ A dictionary with the following keys:
288
+ "audio" : Tensor[B x 1 x length]
289
+ Decoded audio data.
290
+ """
291
+ return self.decoder(z)
292
+
293
+ def forward(
294
+ self,
295
+ audio_data: torch.Tensor,
296
+ sample_rate: int = None,
297
+ n_quantizers: int = None,
298
+ ):
299
+ """Model forward pass
300
+
301
+ Parameters
302
+ ----------
303
+ audio_data : Tensor[B x 1 x T]
304
+ Audio data to encode
305
+ sample_rate : int, optional
306
+ Sample rate of audio data in Hz, by default None
307
+ If None, defaults to `self.sample_rate`
308
+ n_quantizers : int, optional
309
+ Number of quantizers to use, by default None.
310
+ If None, all quantizers are used.
311
+
312
+ Returns
313
+ -------
314
+ dict
315
+ A dictionary with the following keys:
316
+ "z" : Tensor[B x D x T]
317
+ Quantized continuous representation of input
318
+ "codes" : Tensor[B x N x T]
319
+ Codebook indices for each codebook
320
+ (quantized discrete representation of input)
321
+ "latents" : Tensor[B x N*D x T]
322
+ Projected latents (continuous representation of input before quantization)
323
+ "vq/commitment_loss" : Tensor[1]
324
+ Commitment loss to train encoder to predict vectors closer to codebook
325
+ entries
326
+ "vq/codebook_loss" : Tensor[1]
327
+ Codebook loss to update the codebook
328
+ "length" : int
329
+ Number of samples in input audio
330
+ "audio" : Tensor[B x 1 x length]
331
+ Decoded audio data.
332
+ """
333
+ length = audio_data.shape[-1]
334
+ audio_data = self.preprocess(audio_data, sample_rate)
335
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(
336
+ audio_data, n_quantizers
337
+ )
338
+
339
+ x = self.decode(z)
340
+ return {
341
+ "audio": x[..., :length],
342
+ "z": z,
343
+ "codes": codes,
344
+ "latents": latents,
345
+ "vq/commitment_loss": commitment_loss,
346
+ "vq/codebook_loss": codebook_loss,
347
+ }
348
+
349
+
350
+ if __name__ == "__main__":
351
+ import numpy as np
352
+ from functools import partial
353
+
354
+ model = DAC().to("cpu")
355
+
356
+ for n, m in model.named_modules():
357
+ o = m.extra_repr()
358
+ p = sum([np.prod(p.size()) for p in m.parameters()])
359
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
360
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
361
+ print(model)
362
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
363
+
364
+ length = 88200 * 2
365
+ x = torch.randn(1, 1, length).to(model.device)
366
+ x.requires_grad_(True)
367
+ x.retain_grad()
368
+
369
+ # Make a forward pass
370
+ out = model(x)["audio"]
371
+ print("Input shape:", x.shape)
372
+ print("Output shape:", out.shape)
373
+
374
+ # Create gradient variable
375
+ grad = torch.zeros_like(out)
376
+ grad[:, :, grad.shape[-1] // 2] = 1
377
+
378
+ # Make a backward pass
379
+ out.backward(grad)
380
+
381
+ # Check non-zero values
382
+ gradmap = x.grad.squeeze(0)
383
+ gradmap = (gradmap != 0).sum(0) # sum across features
384
+ rf = (gradmap != 0).sum()
385
+
386
+ print(f"Receptive field: {rf.item()}")
387
+
388
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
389
+ model.decompress(model.compress(x, verbose=True), verbose=True)
dac/model/discriminator.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from audiotools import AudioSignal
5
+ from audiotools import ml
6
+ from audiotools import STFTParams
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ act = kwargs.pop("act", True)
13
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
+ if not act:
15
+ return conv
16
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
+
18
+
19
+ def WNConv2d(*args, **kwargs):
20
+ act = kwargs.pop("act", True)
21
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
+ if not act:
23
+ return conv
24
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
+
26
+
27
+ class MPD(nn.Module):
28
+ def __init__(self, period):
29
+ super().__init__()
30
+ self.period = period
31
+ self.convs = nn.ModuleList(
32
+ [
33
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
+ ]
39
+ )
40
+ self.conv_post = WNConv2d(
41
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
+ )
43
+
44
+ def pad_to_period(self, x):
45
+ t = x.shape[-1]
46
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
+ return x
48
+
49
+ def forward(self, x):
50
+ fmap = []
51
+
52
+ x = self.pad_to_period(x)
53
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
+
55
+ for layer in self.convs:
56
+ x = layer(x)
57
+ fmap.append(x)
58
+
59
+ x = self.conv_post(x)
60
+ fmap.append(x)
61
+
62
+ return fmap
63
+
64
+
65
+ class MSD(nn.Module):
66
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
+ super().__init__()
68
+ self.convs = nn.ModuleList(
69
+ [
70
+ WNConv1d(1, 16, 15, 1, padding=7),
71
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
+ WNConv1d(1024, 1024, 5, 1, padding=2),
76
+ ]
77
+ )
78
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
+ self.sample_rate = sample_rate
80
+ self.rate = rate
81
+
82
+ def forward(self, x):
83
+ x = AudioSignal(x, self.sample_rate)
84
+ x.resample(self.sample_rate // self.rate)
85
+ x = x.audio_data
86
+
87
+ fmap = []
88
+
89
+ for l in self.convs:
90
+ x = l(x)
91
+ fmap.append(x)
92
+ x = self.conv_post(x)
93
+ fmap.append(x)
94
+
95
+ return fmap
96
+
97
+
98
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
+
100
+
101
+ class MRD(nn.Module):
102
+ def __init__(
103
+ self,
104
+ window_length: int,
105
+ hop_factor: float = 0.25,
106
+ sample_rate: int = 44100,
107
+ bands: list = BANDS,
108
+ ):
109
+ """Complex multi-band spectrogram discriminator.
110
+ Parameters
111
+ ----------
112
+ window_length : int
113
+ Window length of STFT.
114
+ hop_factor : float, optional
115
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
+ sample_rate : int, optional
117
+ Sampling rate of audio in Hz, by default 44100
118
+ bands : list, optional
119
+ Bands to run discriminator over.
120
+ """
121
+ super().__init__()
122
+
123
+ self.window_length = window_length
124
+ self.hop_factor = hop_factor
125
+ self.sample_rate = sample_rate
126
+ self.stft_params = STFTParams(
127
+ window_length=window_length,
128
+ hop_length=int(window_length * hop_factor),
129
+ match_stride=True,
130
+ )
131
+
132
+ n_fft = window_length // 2 + 1
133
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
+ self.bands = bands
135
+
136
+ ch = 32
137
+ convs = lambda: nn.ModuleList(
138
+ [
139
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
+ ]
145
+ )
146
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
+
149
+ def spectrogram(self, x):
150
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
+ x = torch.view_as_real(x.stft())
152
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
+ # Split into bands
154
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
+ return x_bands
156
+
157
+ def forward(self, x):
158
+ x_bands = self.spectrogram(x)
159
+ fmap = []
160
+
161
+ x = []
162
+ for band, stack in zip(x_bands, self.band_convs):
163
+ for layer in stack:
164
+ band = layer(band)
165
+ fmap.append(band)
166
+ x.append(band)
167
+
168
+ x = torch.cat(x, dim=-1)
169
+ x = self.conv_post(x)
170
+ fmap.append(x)
171
+
172
+ return fmap
173
+
174
+
175
+ class Discriminator(nn.Module):
176
+ def __init__(
177
+ self,
178
+ rates: list = [],
179
+ periods: list = [2, 3, 5, 7, 11],
180
+ fft_sizes: list = [2048, 1024, 512],
181
+ sample_rate: int = 44100,
182
+ bands: list = BANDS,
183
+ ):
184
+ """Discriminator that combines multiple discriminators.
185
+
186
+ Parameters
187
+ ----------
188
+ rates : list, optional
189
+ sampling rates (in Hz) to run MSD at, by default []
190
+ If empty, MSD is not used.
191
+ periods : list, optional
192
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
+ fft_sizes : list, optional
194
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
+ sample_rate : int, optional
196
+ Sampling rate of audio in Hz, by default 44100
197
+ bands : list, optional
198
+ Bands to run MRD at, by default `BANDS`
199
+ """
200
+ super().__init__()
201
+ discs = []
202
+ discs += [MPD(p) for p in periods]
203
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
+ self.discriminators = nn.ModuleList(discs)
206
+
207
+ def preprocess(self, y):
208
+ # Remove DC offset
209
+ y = y - y.mean(dim=-1, keepdims=True)
210
+ # Peak normalize the volume of input audio
211
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
+ return y
213
+
214
+ def forward(self, x):
215
+ x = self.preprocess(x)
216
+ fmaps = [d(x) for d in self.discriminators]
217
+ return fmaps
218
+
219
+
220
+ if __name__ == "__main__":
221
+ disc = Discriminator()
222
+ x = torch.zeros(1, 1, 44100)
223
+ results = disc(x)
224
+ for i, result in enumerate(results):
225
+ print(f"disc{i}")
226
+ for i, r in enumerate(result):
227
+ print(r.shape, r.mean(), r.min(), r.max())
228
+ print()
dac/model/encodec.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Convolutional layers wrappers and utilities."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+
18
+ import typing as tp
19
+
20
+ import einops
21
+
22
+
23
+ class ConvLayerNorm(nn.LayerNorm):
24
+ """
25
+ Convolution-friendly LayerNorm that moves channels to last dimensions
26
+ before running the normalization and moves them back to original position right after.
27
+ """
28
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
29
+ super().__init__(normalized_shape, **kwargs)
30
+
31
+ def forward(self, x):
32
+ x = einops.rearrange(x, 'b ... t -> b t ...')
33
+ x = super().forward(x)
34
+ x = einops.rearrange(x, 'b t ... -> b ... t')
35
+ return
36
+
37
+
38
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
39
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
40
+
41
+
42
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
43
+ assert norm in CONV_NORMALIZATIONS
44
+ if norm == 'weight_norm':
45
+ return weight_norm(module)
46
+ elif norm == 'spectral_norm':
47
+ return spectral_norm(module)
48
+ else:
49
+ # We already check was in CONV_NORMALIZATION, so any other choice
50
+ # doesn't need reparametrization.
51
+ return module
52
+
53
+
54
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
55
+ """Return the proper normalization module. If causal is True, this will ensure the returned
56
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
57
+ """
58
+ assert norm in CONV_NORMALIZATIONS
59
+ if norm == 'layer_norm':
60
+ assert isinstance(module, nn.modules.conv._ConvNd)
61
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
62
+ elif norm == 'time_group_norm':
63
+ if causal:
64
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
65
+ assert isinstance(module, nn.modules.conv._ConvNd)
66
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
67
+ else:
68
+ return nn.Identity()
69
+
70
+
71
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
72
+ padding_total: int = 0) -> int:
73
+ """See `pad_for_conv1d`.
74
+ """
75
+ length = x.shape[-1]
76
+ n_frames = (length - kernel_size + padding_total) / stride + 1
77
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
78
+ return ideal_length - length
79
+
80
+
81
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
82
+ """Pad for a convolution to make sure that the last window is full.
83
+ Extra padding is added at the end. This is required to ensure that we can rebuild
84
+ an output of the same length, as otherwise, even with padding, some time steps
85
+ might get removed.
86
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
87
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
88
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
89
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
90
+ 1 2 3 4 # once you removed padding, we are missing one time step !
91
+ """
92
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
93
+ return F.pad(x, (0, extra_padding))
94
+
95
+
96
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
97
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
98
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
99
+ """
100
+ length = x.shape[-1]
101
+ padding_left, padding_right = paddings
102
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103
+ if mode == 'reflect':
104
+ max_pad = max(padding_left, padding_right)
105
+ extra_pad = 0
106
+ if length <= max_pad:
107
+ extra_pad = max_pad - length + 1
108
+ x = F.pad(x, (0, extra_pad))
109
+ padded = F.pad(x, paddings, mode, value)
110
+ end = padded.shape[-1] - extra_pad
111
+ return padded[..., :end]
112
+ else:
113
+ return F.pad(x, paddings, mode, value)
114
+
115
+
116
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
117
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
118
+ padding_left, padding_right = paddings
119
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
120
+ assert (padding_left + padding_right) <= x.shape[-1]
121
+ end = x.shape[-1] - padding_right
122
+ return x[..., padding_left: end]
123
+
124
+
125
+ class NormConv1d(nn.Module):
126
+ """Wrapper around Conv1d and normalization applied to this conv
127
+ to provide a uniform interface across normalization approaches.
128
+ """
129
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
130
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131
+ super().__init__()
132
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
133
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
134
+ self.norm_type = norm
135
+
136
+ def forward(self, x):
137
+ x = self.conv(x)
138
+ x = self.norm(x)
139
+ return x
140
+
141
+
142
+ class NormConv2d(nn.Module):
143
+ """Wrapper around Conv2d and normalization applied to this conv
144
+ to provide a uniform interface across normalization approaches.
145
+ """
146
+ def __init__(self, *args, norm: str = 'none',
147
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148
+ super().__init__()
149
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
150
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
151
+ self.norm_type = norm
152
+
153
+ def forward(self, x):
154
+ x = self.conv(x)
155
+ x = self.norm(x)
156
+ return x
157
+
158
+
159
+ class NormConvTranspose1d(nn.Module):
160
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
161
+ to provide a uniform interface across normalization approaches.
162
+ """
163
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
164
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165
+ super().__init__()
166
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
167
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
168
+ self.norm_type = norm
169
+
170
+ def forward(self, x):
171
+ x = self.convtr(x)
172
+ x = self.norm(x)
173
+ return x
174
+
175
+
176
+ class NormConvTranspose2d(nn.Module):
177
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
178
+ to provide a uniform interface across normalization approaches.
179
+ """
180
+ def __init__(self, *args, norm: str = 'none',
181
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182
+ super().__init__()
183
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
184
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
185
+
186
+ def forward(self, x):
187
+ x = self.convtr(x)
188
+ x = self.norm(x)
189
+ return x
190
+
191
+
192
+ class SConv1d(nn.Module):
193
+ """Conv1d with some builtin handling of asymmetric or causal padding
194
+ and normalization.
195
+ """
196
+ def __init__(self, in_channels: int, out_channels: int,
197
+ kernel_size: int, stride: int = 1, dilation: int = 1,
198
+ groups: int = 1, bias: bool = True, causal: bool = False,
199
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
200
+ pad_mode: str = 'reflect', **kwargs):
201
+ super().__init__()
202
+ # warn user on unusual setup between dilation and stride
203
+ if stride > 1 and dilation > 1:
204
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
205
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
206
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
207
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
208
+ norm=norm, norm_kwargs=norm_kwargs)
209
+ self.causal = causal
210
+ self.pad_mode = pad_mode
211
+
212
+ def forward(self, x):
213
+ B, C, T = x.shape
214
+ kernel_size = self.conv.conv.kernel_size[0]
215
+ stride = self.conv.conv.stride[0]
216
+ dilation = self.conv.conv.dilation[0]
217
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
218
+ padding_total = kernel_size - stride
219
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
220
+ if self.causal:
221
+ # Left padding for causal
222
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
223
+ else:
224
+ # Asymmetric padding required for odd strides
225
+ padding_right = padding_total // 2
226
+ padding_left = padding_total - padding_right
227
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
228
+ return self.conv(x)
229
+
230
+
231
+ class SConvTranspose1d(nn.Module):
232
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
233
+ and normalization.
234
+ """
235
+ def __init__(self, in_channels: int, out_channels: int,
236
+ kernel_size: int, stride: int = 1, causal: bool = False,
237
+ norm: str = 'none', trim_right_ratio: float = 1.,
238
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
239
+ super().__init__()
240
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
241
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
242
+ self.causal = causal
243
+ self.trim_right_ratio = trim_right_ratio
244
+ assert self.causal or self.trim_right_ratio == 1., \
245
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
246
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
247
+
248
+ def forward(self, x):
249
+ kernel_size = self.convtr.convtr.kernel_size[0]
250
+ stride = self.convtr.convtr.stride[0]
251
+ padding_total = kernel_size - stride
252
+
253
+ y = self.convtr(x)
254
+
255
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
256
+ # removed at the very end, when keeping only the right length for the output,
257
+ # as removing it here would require also passing the length at the matching layer
258
+ # in the encoder.
259
+ if self.causal:
260
+ # Trim the padding on the right according to the specified ratio
261
+ # if trim_right_ratio = 1.0, trim everything from right
262
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
263
+ padding_left = padding_total - padding_right
264
+ y = unpad1d(y, (padding_left, padding_right))
265
+ else:
266
+ # Asymmetric padding required for odd strides
267
+ padding_right = padding_total // 2
268
+ padding_left = padding_total - padding_right
269
+ y = unpad1d(y, (padding_left, padding_right))
270
+ return y
271
+
272
+ class SLSTM(nn.Module):
273
+ """
274
+ LSTM without worrying about the hidden state, nor the layout of the data.
275
+ Expects input as convolutional layout.
276
+ """
277
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
278
+ super().__init__()
279
+ self.skip = skip
280
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
281
+
282
+ def forward(self, x):
283
+ x = x.permute(2, 0, 1)
284
+ y, _ = self.lstm(x)
285
+ if self.skip:
286
+ y = y + x
287
+ y = y.permute(1, 2, 0)
288
+ return y
dac/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
dac/nn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (211 Bytes). View file
 
dac/nn/__pycache__/layers.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
dac/nn/__pycache__/loss.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
dac/nn/__pycache__/quantize.cpython-310.pyc ADDED
Binary file (8.65 kB). View file
 
dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
dac/nn/loss.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
+ class L1Loss(nn.L1Loss):
12
+ """L1 Loss between AudioSignals. Defaults
13
+ to comparing ``audio_data``, but any
14
+ attribute of an AudioSignal can be used.
15
+
16
+ Parameters
17
+ ----------
18
+ attribute : str, optional
19
+ Attribute of signal to compare, defaults to ``audio_data``.
20
+ weight : float, optional
21
+ Weight of this loss, defaults to 1.0.
22
+
23
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
+ """
25
+
26
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
+ self.attribute = attribute
28
+ self.weight = weight
29
+ super().__init__(**kwargs)
30
+
31
+ def forward(self, x: AudioSignal, y: AudioSignal):
32
+ """
33
+ Parameters
34
+ ----------
35
+ x : AudioSignal
36
+ Estimate AudioSignal
37
+ y : AudioSignal
38
+ Reference AudioSignal
39
+
40
+ Returns
41
+ -------
42
+ torch.Tensor
43
+ L1 loss between AudioSignal attributes.
44
+ """
45
+ if isinstance(x, AudioSignal):
46
+ x = getattr(x, self.attribute)
47
+ y = getattr(y, self.attribute)
48
+ return super().forward(x, y)
49
+
50
+
51
+ class SISDRLoss(nn.Module):
52
+ """
53
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
+ of estimated and reference audio signals or aligned features.
55
+
56
+ Parameters
57
+ ----------
58
+ scaling : int, optional
59
+ Whether to use scale-invariant (True) or
60
+ signal-to-noise ratio (False), by default True
61
+ reduction : str, optional
62
+ How to reduce across the batch (either 'mean',
63
+ 'sum', or none).], by default ' mean'
64
+ zero_mean : int, optional
65
+ Zero mean the references and estimates before
66
+ computing the loss, by default True
67
+ clip_min : int, optional
68
+ The minimum possible loss value. Helps network
69
+ to not focus on making already good examples better, by default None
70
+ weight : float, optional
71
+ Weight of this loss, defaults to 1.0.
72
+
73
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ scaling: int = True,
79
+ reduction: str = "mean",
80
+ zero_mean: int = True,
81
+ clip_min: int = None,
82
+ weight: float = 1.0,
83
+ ):
84
+ self.scaling = scaling
85
+ self.reduction = reduction
86
+ self.zero_mean = zero_mean
87
+ self.clip_min = clip_min
88
+ self.weight = weight
89
+ super().__init__()
90
+
91
+ def forward(self, x: AudioSignal, y: AudioSignal):
92
+ eps = 1e-8
93
+ # nb, nc, nt
94
+ if isinstance(x, AudioSignal):
95
+ references = x.audio_data
96
+ estimates = y.audio_data
97
+ else:
98
+ references = x
99
+ estimates = y
100
+
101
+ nb = references.shape[0]
102
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
+
105
+ # samples now on axis 1
106
+ if self.zero_mean:
107
+ mean_reference = references.mean(dim=1, keepdim=True)
108
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
109
+ else:
110
+ mean_reference = 0
111
+ mean_estimate = 0
112
+
113
+ _references = references - mean_reference
114
+ _estimates = estimates - mean_estimate
115
+
116
+ references_projection = (_references**2).sum(dim=-2) + eps
117
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
+
119
+ scale = (
120
+ (references_on_estimates / references_projection).unsqueeze(1)
121
+ if self.scaling
122
+ else 1
123
+ )
124
+
125
+ e_true = scale * _references
126
+ e_res = _estimates - e_true
127
+
128
+ signal = (e_true**2).sum(dim=1)
129
+ noise = (e_res**2).sum(dim=1)
130
+ sdr = -10 * torch.log10(signal / noise + eps)
131
+
132
+ if self.clip_min is not None:
133
+ sdr = torch.clamp(sdr, min=self.clip_min)
134
+
135
+ if self.reduction == "mean":
136
+ sdr = sdr.mean()
137
+ elif self.reduction == "sum":
138
+ sdr = sdr.sum()
139
+ return sdr
140
+
141
+
142
+ class MultiScaleSTFTLoss(nn.Module):
143
+ """Computes the multi-scale STFT loss from [1].
144
+
145
+ Parameters
146
+ ----------
147
+ window_lengths : List[int], optional
148
+ Length of each window of each STFT, by default [2048, 512]
149
+ loss_fn : typing.Callable, optional
150
+ How to compare each loss, by default nn.L1Loss()
151
+ clamp_eps : float, optional
152
+ Clamp on the log magnitude, below, by default 1e-5
153
+ mag_weight : float, optional
154
+ Weight of raw magnitude portion of loss, by default 1.0
155
+ log_weight : float, optional
156
+ Weight of log magnitude portion of loss, by default 1.0
157
+ pow : float, optional
158
+ Power to raise magnitude to before taking log, by default 2.0
159
+ weight : float, optional
160
+ Weight of this loss, by default 1.0
161
+ match_stride : bool, optional
162
+ Whether to match the stride of convolutional layers, by default False
163
+
164
+ References
165
+ ----------
166
+
167
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
+ "DDSP: Differentiable Digital Signal Processing."
169
+ International Conference on Learning Representations. 2019.
170
+
171
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ window_lengths: List[int] = [2048, 512],
177
+ loss_fn: typing.Callable = nn.L1Loss(),
178
+ clamp_eps: float = 1e-5,
179
+ mag_weight: float = 1.0,
180
+ log_weight: float = 1.0,
181
+ pow: float = 2.0,
182
+ weight: float = 1.0,
183
+ match_stride: bool = False,
184
+ window_type: str = None,
185
+ ):
186
+ super().__init__()
187
+ self.stft_params = [
188
+ STFTParams(
189
+ window_length=w,
190
+ hop_length=w // 4,
191
+ match_stride=match_stride,
192
+ window_type=window_type,
193
+ )
194
+ for w in window_lengths
195
+ ]
196
+ self.loss_fn = loss_fn
197
+ self.log_weight = log_weight
198
+ self.mag_weight = mag_weight
199
+ self.clamp_eps = clamp_eps
200
+ self.weight = weight
201
+ self.pow = pow
202
+
203
+ def forward(self, x: AudioSignal, y: AudioSignal):
204
+ """Computes multi-scale STFT between an estimate and a reference
205
+ signal.
206
+
207
+ Parameters
208
+ ----------
209
+ x : AudioSignal
210
+ Estimate signal
211
+ y : AudioSignal
212
+ Reference signal
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Multi-scale STFT loss.
218
+ """
219
+ loss = 0.0
220
+ for s in self.stft_params:
221
+ x.stft(s.window_length, s.hop_length, s.window_type)
222
+ y.stft(s.window_length, s.hop_length, s.window_type)
223
+ loss += self.log_weight * self.loss_fn(
224
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
+ )
227
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
+ return loss
229
+
230
+
231
+ class MelSpectrogramLoss(nn.Module):
232
+ """Compute distance between mel spectrograms. Can be used
233
+ in a multi-scale way.
234
+
235
+ Parameters
236
+ ----------
237
+ n_mels : List[int]
238
+ Number of mels per STFT, by default [150, 80],
239
+ window_lengths : List[int], optional
240
+ Length of each window of each STFT, by default [2048, 512]
241
+ loss_fn : typing.Callable, optional
242
+ How to compare each loss, by default nn.L1Loss()
243
+ clamp_eps : float, optional
244
+ Clamp on the log magnitude, below, by default 1e-5
245
+ mag_weight : float, optional
246
+ Weight of raw magnitude portion of loss, by default 1.0
247
+ log_weight : float, optional
248
+ Weight of log magnitude portion of loss, by default 1.0
249
+ pow : float, optional
250
+ Power to raise magnitude to before taking log, by default 2.0
251
+ weight : float, optional
252
+ Weight of this loss, by default 1.0
253
+ match_stride : bool, optional
254
+ Whether to match the stride of convolutional layers, by default False
255
+
256
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_mels: List[int] = [150, 80],
262
+ window_lengths: List[int] = [2048, 512],
263
+ loss_fn: typing.Callable = nn.L1Loss(),
264
+ clamp_eps: float = 1e-5,
265
+ mag_weight: float = 1.0,
266
+ log_weight: float = 1.0,
267
+ pow: float = 2.0,
268
+ weight: float = 1.0,
269
+ match_stride: bool = False,
270
+ mel_fmin: List[float] = [0.0, 0.0],
271
+ mel_fmax: List[float] = [None, None],
272
+ window_type: str = None,
273
+ ):
274
+ super().__init__()
275
+ self.stft_params = [
276
+ STFTParams(
277
+ window_length=w,
278
+ hop_length=w // 4,
279
+ match_stride=match_stride,
280
+ window_type=window_type,
281
+ )
282
+ for w in window_lengths
283
+ ]
284
+ self.n_mels = n_mels
285
+ self.loss_fn = loss_fn
286
+ self.clamp_eps = clamp_eps
287
+ self.log_weight = log_weight
288
+ self.mag_weight = mag_weight
289
+ self.weight = weight
290
+ self.mel_fmin = mel_fmin
291
+ self.mel_fmax = mel_fmax
292
+ self.pow = pow
293
+
294
+ def forward(self, x: AudioSignal, y: AudioSignal):
295
+ """Computes mel loss between an estimate and a reference
296
+ signal.
297
+
298
+ Parameters
299
+ ----------
300
+ x : AudioSignal
301
+ Estimate signal
302
+ y : AudioSignal
303
+ Reference signal
304
+
305
+ Returns
306
+ -------
307
+ torch.Tensor
308
+ Mel loss.
309
+ """
310
+ loss = 0.0
311
+ for n_mels, fmin, fmax, s in zip(
312
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
+ ):
314
+ kwargs = {
315
+ "window_length": s.window_length,
316
+ "hop_length": s.hop_length,
317
+ "window_type": s.window_type,
318
+ }
319
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
+
322
+ loss += self.log_weight * self.loss_fn(
323
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
+ )
326
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
+ return loss
328
+
329
+
330
+ class GANLoss(nn.Module):
331
+ """
332
+ Computes a discriminator loss, given a discriminator on
333
+ generated waveforms/spectrograms compared to ground truth
334
+ waveforms/spectrograms. Computes the loss for both the
335
+ discriminator and the generator in separate functions.
336
+ """
337
+
338
+ def __init__(self, discriminator):
339
+ super().__init__()
340
+ self.discriminator = discriminator
341
+
342
+ def forward(self, fake, real):
343
+ d_fake = self.discriminator(fake.audio_data)
344
+ d_real = self.discriminator(real.audio_data)
345
+ return d_fake, d_real
346
+
347
+ def discriminator_loss(self, fake, real):
348
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
349
+
350
+ loss_d = 0
351
+ for x_fake, x_real in zip(d_fake, d_real):
352
+ loss_d += torch.mean(x_fake[-1] ** 2)
353
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
+ return loss_d
355
+
356
+ def generator_loss(self, fake, real):
357
+ d_fake, d_real = self.forward(fake, real)
358
+
359
+ loss_g = 0
360
+ for x_fake in d_fake:
361
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
+
363
+ loss_feature = 0
364
+
365
+ for i in range(len(d_fake)):
366
+ for j in range(len(d_fake[i]) - 1):
367
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
+ return loss_g, loss_feature
dac/nn/quantize.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = (
65
+ z_e + (z_q - z_e).detach()
66
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
67
+
68
+ z_q = self.out_proj(z_q)
69
+
70
+ return z_q, commitment_loss, codebook_loss, indices, z_e
71
+
72
+ def embed_code(self, embed_id):
73
+ return F.embedding(embed_id, self.codebook.weight)
74
+
75
+ def decode_code(self, embed_id):
76
+ return self.embed_code(embed_id).transpose(1, 2)
77
+
78
+ def decode_latents(self, latents):
79
+ encodings = rearrange(latents, "b d t -> (b t) d")
80
+ codebook = self.codebook.weight # codebook: (N x D)
81
+
82
+ # L2 normalize encodings and codebook (ViT-VQGAN)
83
+ encodings = F.normalize(encodings)
84
+ codebook = F.normalize(codebook)
85
+
86
+ # Compute euclidean distance with codebook
87
+ dist = (
88
+ encodings.pow(2).sum(1, keepdim=True)
89
+ - 2 * encodings @ codebook.t()
90
+ + codebook.pow(2).sum(1, keepdim=True).t()
91
+ )
92
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
93
+ z_q = self.decode_code(indices)
94
+ return z_q, indices
95
+
96
+
97
+ class ResidualVectorQuantize(nn.Module):
98
+ """
99
+ Introduced in SoundStream: An end2end neural audio codec
100
+ https://arxiv.org/abs/2107.03312
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ input_dim: int = 512,
106
+ n_codebooks: int = 9,
107
+ codebook_size: int = 1024,
108
+ codebook_dim: Union[int, list] = 8,
109
+ quantizer_dropout: float = 0.0,
110
+ ):
111
+ super().__init__()
112
+ if isinstance(codebook_dim, int):
113
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114
+
115
+ self.n_codebooks = n_codebooks
116
+ self.codebook_dim = codebook_dim
117
+ self.codebook_size = codebook_size
118
+
119
+ self.quantizers = nn.ModuleList(
120
+ [
121
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122
+ for i in range(n_codebooks)
123
+ ]
124
+ )
125
+ self.quantizer_dropout = quantizer_dropout
126
+
127
+ def forward(self, z, n_quantizers: int = None):
128
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
129
+ the corresponding codebook vectors
130
+ Parameters
131
+ ----------
132
+ z : Tensor[B x D x T]
133
+ n_quantizers : int, optional
134
+ No. of quantizers to use
135
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
137
+ when in training mode, and a random number of quantizers is used.
138
+ Returns
139
+ -------
140
+ dict
141
+ A dictionary with the following keys:
142
+
143
+ "z" : Tensor[B x D x T]
144
+ Quantized continuous representation of input
145
+ "codes" : Tensor[B x N x T]
146
+ Codebook indices for each codebook
147
+ (quantized discrete representation of input)
148
+ "latents" : Tensor[B x N*D x T]
149
+ Projected latents (continuous representation of input before quantization)
150
+ "vq/commitment_loss" : Tensor[1]
151
+ Commitment loss to train encoder to predict vectors closer to codebook
152
+ entries
153
+ "vq/codebook_loss" : Tensor[1]
154
+ Codebook loss to update the codebook
155
+ """
156
+ z_q = 0
157
+ residual = z
158
+ commitment_loss = 0
159
+ codebook_loss = 0
160
+
161
+ codebook_indices = []
162
+ latents = []
163
+
164
+ if n_quantizers is None:
165
+ n_quantizers = self.n_codebooks
166
+ if self.training:
167
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
170
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
171
+ n_quantizers = n_quantizers.to(z.device)
172
+
173
+ for i, quantizer in enumerate(self.quantizers):
174
+ if self.training is False and i >= n_quantizers:
175
+ break
176
+
177
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178
+ residual
179
+ )
180
+
181
+ # Create mask to apply quantizer dropout
182
+ mask = (
183
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184
+ )
185
+ z_q = z_q + z_q_i * mask[:, None, None]
186
+ residual = residual - z_q_i
187
+
188
+ # Sum losses
189
+ commitment_loss += (commitment_loss_i * mask).mean()
190
+ codebook_loss += (codebook_loss_i * mask).mean()
191
+
192
+ codebook_indices.append(indices_i)
193
+ latents.append(z_e_i)
194
+
195
+ codes = torch.stack(codebook_indices, dim=1)
196
+ latents = torch.cat(latents, dim=1)
197
+
198
+ return z_q, codes, latents, commitment_loss, codebook_loss
199
+
200
+ def from_codes(self, codes: torch.Tensor):
201
+ """Given the quantized codes, reconstruct the continuous representation
202
+ Parameters
203
+ ----------
204
+ codes : Tensor[B x N x T]
205
+ Quantized discrete representation of input
206
+ Returns
207
+ -------
208
+ Tensor[B x D x T]
209
+ Quantized continuous representation of input
210
+ """
211
+ z_q = 0.0
212
+ z_p = []
213
+ n_codebooks = codes.shape[1]
214
+ for i in range(n_codebooks):
215
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
216
+ z_p.append(z_p_i)
217
+
218
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
219
+ z_q = z_q + z_q_i
220
+ return z_q, torch.cat(z_p, dim=1), codes
221
+
222
+ def from_latents(self, latents: torch.Tensor):
223
+ """Given the unquantized latents, reconstruct the
224
+ continuous representation after quantization.
225
+
226
+ Parameters
227
+ ----------
228
+ latents : Tensor[B x N x T]
229
+ Continuous representation of input after projection
230
+
231
+ Returns
232
+ -------
233
+ Tensor[B x D x T]
234
+ Quantized representation of full-projected space
235
+ Tensor[B x D x T]
236
+ Quantized representation of latent space
237
+ """
238
+ z_q = 0
239
+ z_p = []
240
+ codes = []
241
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
242
+
243
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
244
+ 0
245
+ ]
246
+ for i in range(n_codebooks):
247
+ j, k = dims[i], dims[i + 1]
248
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
249
+ z_p.append(z_p_i)
250
+ codes.append(codes_i)
251
+
252
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
253
+ z_q = z_q + z_q_i
254
+
255
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
260
+ x = torch.randn(16, 512, 80)
261
+ y = rvq(x)
262
+ print(y["latents"].shape)
dac/utils/__init__.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from audiotools import ml
5
+
6
+ import dac
7
+
8
+ DAC = dac.model.DAC
9
+ Accelerator = ml.Accelerator
10
+
11
+ __MODEL_LATEST_TAGS__ = {
12
+ ("44khz", "8kbps"): "0.0.1",
13
+ ("24khz", "8kbps"): "0.0.4",
14
+ ("16khz", "8kbps"): "0.0.5",
15
+ ("44khz", "16kbps"): "1.0.0",
16
+ }
17
+
18
+ __MODEL_URLS__ = {
19
+ (
20
+ "44khz",
21
+ "0.0.1",
22
+ "8kbps",
23
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
24
+ (
25
+ "24khz",
26
+ "0.0.4",
27
+ "8kbps",
28
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
29
+ (
30
+ "16khz",
31
+ "0.0.5",
32
+ "8kbps",
33
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
34
+ (
35
+ "44khz",
36
+ "1.0.0",
37
+ "16kbps",
38
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
39
+ }
40
+
41
+
42
+ @argbind.bind(group="download", positional=True, without_prefix=True)
43
+ def download(
44
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
45
+ ):
46
+ """
47
+ Function that downloads the weights file from URL if a local cache is not found.
48
+
49
+ Parameters
50
+ ----------
51
+ model_type : str
52
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
53
+ model_bitrate: str
54
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
55
+ Only 44khz model supports 16kbps.
56
+ tag : str
57
+ The tag of the model to download. Defaults to "latest".
58
+
59
+ Returns
60
+ -------
61
+ Path
62
+ Directory path required to load model via audiotools.
63
+ """
64
+ model_type = model_type.lower()
65
+ tag = tag.lower()
66
+
67
+ assert model_type in [
68
+ "44khz",
69
+ "24khz",
70
+ "16khz",
71
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
72
+
73
+ assert model_bitrate in [
74
+ "8kbps",
75
+ "16kbps",
76
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
77
+
78
+ if tag == "latest":
79
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
80
+
81
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
82
+
83
+ if download_link is None:
84
+ raise ValueError(
85
+ f"Could not find model with tag {tag} and model type {model_type}"
86
+ )
87
+
88
+ local_path = (
89
+ Path.home()
90
+ / ".cache"
91
+ / "descript"
92
+ / "dac"
93
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
94
+ )
95
+ if not local_path.exists():
96
+ local_path.parent.mkdir(parents=True, exist_ok=True)
97
+
98
+ # Download the model
99
+ import requests
100
+
101
+ response = requests.get(download_link)
102
+
103
+ if response.status_code != 200:
104
+ raise ValueError(
105
+ f"Could not download model. Received response code {response.status_code}"
106
+ )
107
+ local_path.write_bytes(response.content)
108
+
109
+ return local_path
110
+
111
+
112
+ def load_model(
113
+ model_type: str = "44khz",
114
+ model_bitrate: str = "8kbps",
115
+ tag: str = "latest",
116
+ load_path: str = None,
117
+ ):
118
+ if not load_path:
119
+ load_path = download(
120
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
121
+ )
122
+ generator = DAC.load(load_path)
123
+ return generator
dac/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.82 kB). View file
 
dac/utils/decode.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from pathlib import Path
3
+
4
+ import argbind
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from tqdm import tqdm
9
+
10
+ from dac import DACFile
11
+ from dac.utils import load_model
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
+ @argbind.bind(group="decode", positional=True, without_prefix=True)
17
+ @torch.inference_mode()
18
+ @torch.no_grad()
19
+ def decode(
20
+ input: str,
21
+ output: str = "",
22
+ weights_path: str = "",
23
+ model_tag: str = "latest",
24
+ model_bitrate: str = "8kbps",
25
+ device: str = "cuda",
26
+ model_type: str = "44khz",
27
+ verbose: bool = False,
28
+ ):
29
+ """Decode audio from codes.
30
+
31
+ Parameters
32
+ ----------
33
+ input : str
34
+ Path to input directory or file
35
+ output : str, optional
36
+ Path to output directory, by default "".
37
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38
+ weights_path : str, optional
39
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40
+ model_tag and model_type.
41
+ model_tag : str, optional
42
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43
+ model_bitrate: str
44
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45
+ device : str, optional
46
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47
+ model_type : str, optional
48
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49
+ """
50
+ generator = load_model(
51
+ model_type=model_type,
52
+ model_bitrate=model_bitrate,
53
+ tag=model_tag,
54
+ load_path=weights_path,
55
+ )
56
+ generator.to(device)
57
+ generator.eval()
58
+
59
+ # Find all .dac files in input directory
60
+ _input = Path(input)
61
+ input_files = list(_input.glob("**/*.dac"))
62
+
63
+ # If input is a .dac file, add it to the list
64
+ if _input.suffix == ".dac":
65
+ input_files.append(_input)
66
+
67
+ # Create output directory
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
72
+ # Load file
73
+ artifact = DACFile.load(input_files[i])
74
+
75
+ # Reconstruct audio from codes
76
+ recons = generator.decompress(artifact, verbose=verbose)
77
+
78
+ # Compute output path
79
+ relative_path = input_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = input_files[i]
84
+ output_name = relative_path.with_suffix(".wav").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ # Write to file
89
+ recons.write(output_path)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ args = argbind.parse_args()
94
+ with argbind.scope(args):
95
+ decode()
dac/utils/encode.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from pathlib import Path
4
+
5
+ import argbind
6
+ import numpy as np
7
+ import torch
8
+ from audiotools import AudioSignal
9
+ from audiotools.core import util
10
+ from tqdm import tqdm
11
+
12
+ from dac.utils import load_model
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+
16
+
17
+ @argbind.bind(group="encode", positional=True, without_prefix=True)
18
+ @torch.inference_mode()
19
+ @torch.no_grad()
20
+ def encode(
21
+ input: str,
22
+ output: str = "",
23
+ weights_path: str = "",
24
+ model_tag: str = "latest",
25
+ model_bitrate: str = "8kbps",
26
+ n_quantizers: int = None,
27
+ device: str = "cuda",
28
+ model_type: str = "44khz",
29
+ win_duration: float = 5.0,
30
+ verbose: bool = False,
31
+ ):
32
+ """Encode audio files in input path to .dac format.
33
+
34
+ Parameters
35
+ ----------
36
+ input : str
37
+ Path to input audio file or directory
38
+ output : str, optional
39
+ Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40
+ weights_path : str, optional
41
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42
+ model_tag and model_type.
43
+ model_tag : str, optional
44
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45
+ model_bitrate: str
46
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47
+ n_quantizers : int, optional
48
+ Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49
+ device : str, optional
50
+ Device to use, by default "cuda"
51
+ model_type : str, optional
52
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53
+ """
54
+ generator = load_model(
55
+ model_type=model_type,
56
+ model_bitrate=model_bitrate,
57
+ tag=model_tag,
58
+ load_path=weights_path,
59
+ )
60
+ generator.to(device)
61
+ generator.eval()
62
+ kwargs = {"n_quantizers": n_quantizers}
63
+
64
+ # Find all audio files in input path
65
+ input = Path(input)
66
+ audio_files = util.find_audio(input)
67
+
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72
+ # Load file
73
+ signal = AudioSignal(audio_files[i])
74
+
75
+ # Encode audio to .dac format
76
+ artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77
+
78
+ # Compute output path
79
+ relative_path = audio_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = audio_files[i]
84
+ output_name = relative_path.with_suffix(".dac").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ artifact.save(output_path)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ args = argbind.parse_args()
93
+ with argbind.scope(args):
94
+ encode()
hf_utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from huggingface_hub import hf_hub_download
4
+
5
+
6
+ def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"):
7
+ os.makedirs("./checkpoints", exist_ok=True)
8
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
9
+ config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints")
10
+
11
+ return model_path, config_path
modules/__pycache__/attentions.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
modules/__pycache__/commons.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
modules/__pycache__/mamba.cpython-310.pyc ADDED
Binary file (23.7 kB). View file
 
modules/__pycache__/quantize.cpython-310.pyc ADDED
Binary file (15.4 kB). View file
 
modules/__pycache__/redecoder.cpython-310.pyc ADDED
Binary file (3.2 kB). View file
 
modules/__pycache__/style_encoder.cpython-310.pyc ADDED
Binary file (3.05 kB). View file
 
modules/__pycache__/wavenet.cpython-310.pyc ADDED
Binary file (5.16 kB). View file
 
modules/attentions.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from . import commons
9
+ class LayerNorm(nn.Module):
10
+ def __init__(self, channels, eps=1e-5):
11
+ super().__init__()
12
+ self.channels = channels
13
+ self.eps = eps
14
+
15
+ self.gamma = nn.Parameter(torch.ones(channels))
16
+ self.beta = nn.Parameter(torch.zeros(channels))
17
+
18
+ def forward(self, x):
19
+ x = x.transpose(1, -1)
20
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
21
+ return x.transpose(1, -1)
22
+
23
+
24
+ class Encoder(nn.Module):
25
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,
26
+ **kwargs):
27
+ super().__init__()
28
+ self.hidden_channels = hidden_channels
29
+ self.filter_channels = filter_channels
30
+ self.n_heads = n_heads
31
+ self.n_layers = n_layers
32
+ self.kernel_size = kernel_size
33
+ self.p_dropout = p_dropout
34
+ self.window_size = window_size
35
+
36
+ self.drop = nn.Dropout(p_dropout)
37
+ self.attn_layers = nn.ModuleList()
38
+ self.norm_layers_1 = nn.ModuleList()
39
+ self.ffn_layers = nn.ModuleList()
40
+ self.norm_layers_2 = nn.ModuleList()
41
+ for i in range(self.n_layers):
42
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
43
+ window_size=window_size))
44
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
45
+ self.ffn_layers.append(
46
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
47
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
48
+
49
+ def forward(self, x, x_mask):
50
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
51
+ x = x * x_mask
52
+ for i in range(self.n_layers):
53
+ y = self.attn_layers[i](x, x, attn_mask)
54
+ y = self.drop(y)
55
+ x = self.norm_layers_1[i](x + y)
56
+
57
+ y = self.ffn_layers[i](x, x_mask)
58
+ y = self.drop(y)
59
+ x = self.norm_layers_2[i](x + y)
60
+ x = x * x_mask
61
+ return x
62
+
63
+
64
+ class Decoder(nn.Module):
65
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
66
+ proximal_bias=False, proximal_init=True, **kwargs):
67
+ super().__init__()
68
+ self.hidden_channels = hidden_channels
69
+ self.filter_channels = filter_channels
70
+ self.n_heads = n_heads
71
+ self.n_layers = n_layers
72
+ self.kernel_size = kernel_size
73
+ self.p_dropout = p_dropout
74
+ self.proximal_bias = proximal_bias
75
+ self.proximal_init = proximal_init
76
+
77
+ self.drop = nn.Dropout(p_dropout)
78
+ self.self_attn_layers = nn.ModuleList()
79
+ self.norm_layers_0 = nn.ModuleList()
80
+ self.encdec_attn_layers = nn.ModuleList()
81
+ self.norm_layers_1 = nn.ModuleList()
82
+ self.ffn_layers = nn.ModuleList()
83
+ self.norm_layers_2 = nn.ModuleList()
84
+ for i in range(self.n_layers):
85
+ self.self_attn_layers.append(
86
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
87
+ proximal_bias=proximal_bias, proximal_init=proximal_init))
88
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
89
+ self.encdec_attn_layers.append(
90
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
91
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
92
+ self.ffn_layers.append(
93
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
94
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
95
+
96
+ def forward(self, x, x_mask, h, h_mask):
97
+ """
98
+ x: decoder input
99
+ h: encoder output
100
+ """
101
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
102
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
103
+ x = x * x_mask
104
+ for i in range(self.n_layers):
105
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
106
+ y = self.drop(y)
107
+ x = self.norm_layers_0[i](x + y)
108
+
109
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
110
+ y = self.drop(y)
111
+ x = self.norm_layers_1[i](x + y)
112
+
113
+ y = self.ffn_layers[i](x, x_mask)
114
+ y = self.drop(y)
115
+ x = self.norm_layers_2[i](x + y)
116
+ x = x * x_mask
117
+ return x
118
+
119
+
120
+ class MultiHeadAttention(nn.Module):
121
+ def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True,
122
+ block_length=None, proximal_bias=False, proximal_init=False):
123
+ super().__init__()
124
+ assert channels % n_heads == 0
125
+
126
+ self.channels = channels
127
+ self.out_channels = out_channels
128
+ self.n_heads = n_heads
129
+ self.p_dropout = p_dropout
130
+ self.window_size = window_size
131
+ self.heads_share = heads_share
132
+ self.block_length = block_length
133
+ self.proximal_bias = proximal_bias
134
+ self.proximal_init = proximal_init
135
+ self.attn = None
136
+
137
+ self.k_channels = channels // n_heads
138
+ self.conv_q = nn.Conv1d(channels, channels, 1)
139
+ self.conv_k = nn.Conv1d(channels, channels, 1)
140
+ self.conv_v = nn.Conv1d(channels, channels, 1)
141
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
142
+ self.drop = nn.Dropout(p_dropout)
143
+
144
+ if window_size is not None:
145
+ n_heads_rel = 1 if heads_share else n_heads
146
+ rel_stddev = self.k_channels ** -0.5
147
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
148
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
149
+
150
+ nn.init.xavier_uniform_(self.conv_q.weight)
151
+ nn.init.xavier_uniform_(self.conv_k.weight)
152
+ nn.init.xavier_uniform_(self.conv_v.weight)
153
+ if proximal_init:
154
+ with torch.no_grad():
155
+ self.conv_k.weight.copy_(self.conv_q.weight)
156
+ self.conv_k.bias.copy_(self.conv_q.bias)
157
+
158
+ def forward(self, x, c, attn_mask=None):
159
+ q = self.conv_q(x)
160
+ k = self.conv_k(c)
161
+ v = self.conv_v(c)
162
+
163
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
164
+
165
+ x = self.conv_o(x)
166
+ return x
167
+
168
+ def attention(self, query, key, value, mask=None):
169
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
170
+ b, d, t_s, t_t = (*key.size(), query.size(2))
171
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
172
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
173
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
174
+
175
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
176
+ if self.window_size is not None:
177
+ assert t_s == t_t, "Relative attention is only available for self-attention."
178
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
179
+ rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
180
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
181
+ scores = scores + scores_local
182
+ if self.proximal_bias:
183
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
184
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
185
+ if mask is not None:
186
+ scores = scores.masked_fill(mask == 0, -1e4)
187
+ if self.block_length is not None:
188
+ assert t_s == t_t, "Local attention is only available for self-attention."
189
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
190
+ scores = scores.masked_fill(block_mask == 0, -1e4)
191
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
192
+ p_attn = self.drop(p_attn)
193
+ output = torch.matmul(p_attn, value)
194
+ if self.window_size is not None:
195
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
196
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
197
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
198
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
199
+ return output, p_attn
200
+
201
+ def _matmul_with_relative_values(self, x, y):
202
+ """
203
+ x: [b, h, l, m]
204
+ y: [h or 1, m, d]
205
+ ret: [b, h, l, d]
206
+ """
207
+ ret = torch.matmul(x, y.unsqueeze(0))
208
+ return ret
209
+
210
+ def _matmul_with_relative_keys(self, x, y):
211
+ """
212
+ x: [b, h, l, d]
213
+ y: [h or 1, m, d]
214
+ ret: [b, h, l, m]
215
+ """
216
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
217
+ return ret
218
+
219
+ def _get_relative_embeddings(self, relative_embeddings, length):
220
+ max_relative_position = 2 * self.window_size + 1
221
+ # Pad first before slice to avoid using cond ops.
222
+ pad_length = max(length - (self.window_size + 1), 0)
223
+ slice_start_position = max((self.window_size + 1) - length, 0)
224
+ slice_end_position = slice_start_position + 2 * length - 1
225
+ if pad_length > 0:
226
+ padded_relative_embeddings = F.pad(
227
+ relative_embeddings,
228
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
229
+ else:
230
+ padded_relative_embeddings = relative_embeddings
231
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
232
+ return used_relative_embeddings
233
+
234
+ def _relative_position_to_absolute_position(self, x):
235
+ """
236
+ x: [b, h, l, 2*l-1]
237
+ ret: [b, h, l, l]
238
+ """
239
+ batch, heads, length, _ = x.size()
240
+ # Concat columns of pad to shift from relative to absolute indexing.
241
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
242
+
243
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
244
+ x_flat = x.view([batch, heads, length * 2 * length])
245
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
246
+
247
+ # Reshape and slice out the padded elements.
248
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
249
+ return x_final
250
+
251
+ def _absolute_position_to_relative_position(self, x):
252
+ """
253
+ x: [b, h, l, l]
254
+ ret: [b, h, l, 2*l-1]
255
+ """
256
+ batch, heads, length, _ = x.size()
257
+ # padd along column
258
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
259
+ x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
260
+ # add 0's in the beginning that will skew the elements after reshape
261
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
262
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
263
+ return x_final
264
+
265
+ def _attention_bias_proximal(self, length):
266
+ """Bias for self-attention to encourage attention to close positions.
267
+ Args:
268
+ length: an integer scalar.
269
+ Returns:
270
+ a Tensor with shape [1, 1, length, length]
271
+ """
272
+ r = torch.arange(length, dtype=torch.float32)
273
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
274
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
275
+
276
+
277
+ class FFN(nn.Module):
278
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None,
279
+ causal=False):
280
+ super().__init__()
281
+ self.in_channels = in_channels
282
+ self.out_channels = out_channels
283
+ self.filter_channels = filter_channels
284
+ self.kernel_size = kernel_size
285
+ self.p_dropout = p_dropout
286
+ self.activation = activation
287
+ self.causal = causal
288
+
289
+ if causal:
290
+ self.padding = self._causal_padding
291
+ else:
292
+ self.padding = self._same_padding
293
+
294
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
295
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
296
+ self.drop = nn.Dropout(p_dropout)
297
+
298
+ def forward(self, x, x_mask):
299
+ x = self.conv_1(self.padding(x * x_mask))
300
+ if self.activation == "gelu":
301
+ x = x * torch.sigmoid(1.702 * x)
302
+ else:
303
+ x = torch.relu(x)
304
+ x = self.drop(x)
305
+ x = self.conv_2(self.padding(x * x_mask))
306
+ return x * x_mask
307
+
308
+ def _causal_padding(self, x):
309
+ if self.kernel_size == 1:
310
+ return x
311
+ pad_l = self.kernel_size - 1
312
+ pad_r = 0
313
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
314
+ x = F.pad(x, commons.convert_pad_shape(padding))
315
+ return x
316
+
317
+ def _same_padding(self, x):
318
+ if self.kernel_size == 1:
319
+ return x
320
+ pad_l = (self.kernel_size - 1) // 2
321
+ pad_r = self.kernel_size // 2
322
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
323
+ x = F.pad(x, commons.convert_pad_shape(padding))
324
+ return x
modules/beta_vae.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.distributions as td
4
+ import numpy as np
5
+
6
+ class Swish(nn.Module):
7
+ def __init__(self):
8
+ super(Swish, self).__init__()
9
+ def forward(self, x):
10
+ return x * torch.sigmoid(x)
11
+
12
+ def cycle_interval(starting_value, num_frames, min_val, max_val):
13
+ """Cycles through the state space in a single cycle."""
14
+ starting_in_01 = ((starting_value - min_val) / (max_val - min_val)).cpu()
15
+ grid = torch.linspace(starting_in_01.item(), starting_in_01.item() + 2., steps=num_frames + 1)[:-1]
16
+ grid -= np.maximum(0, 2 * grid - 2)
17
+ grid += np.maximum(0, -2 * grid)
18
+ return grid * (max_val - min_val) + min_val
19
+ class BetaVAE_Linear(nn.Module):
20
+ def __init__(self, in_dim=1024, n_hidden=64, latent=8):
21
+ super(BetaVAE_Linear, self).__init__()
22
+
23
+ self.n_hidden = n_hidden
24
+ self.latent = latent
25
+
26
+ # Encoder
27
+ self.encoder = nn.Sequential(
28
+ nn.Linear(in_dim, n_hidden), Swish(),
29
+ )
30
+
31
+ # Latent
32
+ self.mu = nn.Linear(n_hidden, latent)
33
+ self.lv = nn.Linear(n_hidden, latent)
34
+
35
+ # Decoder
36
+ self.decoder = nn.Sequential(
37
+ nn.Linear(latent, n_hidden), Swish(),
38
+ nn.Linear(n_hidden, in_dim), Swish()
39
+ )
40
+
41
+ def BottomUp(self, x):
42
+ out = self.encoder(x)
43
+ mu, lv = self.mu(out), self.lv(out)
44
+ return mu, lv
45
+
46
+ def reparameterize(self, mu, lv):
47
+ std = torch.exp(0.5 * lv)
48
+ eps = torch.randn_like(std)
49
+ return mu + std * eps
50
+
51
+ def TopDown(self, z):
52
+ out = self.decoder(z)
53
+ return out
54
+
55
+ def forward(self, x):
56
+ # x = x.view(x.shape[0], -1)
57
+ mu, lv = self.BottomUp(x)
58
+ z = self.reparameterize(mu, lv)
59
+ out = self.TopDown(z)
60
+ return out, mu, lv
61
+
62
+ def calc_loss(self, x, beta):
63
+ mu, lv = self.BottomUp(x)
64
+ z = self.reparameterize(mu, lv)
65
+ out = torch.sigmoid(self.TopDown(z))
66
+
67
+ nll = -nn.functional.binary_cross_entropy(out, x, reduction='sum') / x.shape[0]
68
+ kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / x.shape[0]
69
+ # print(kl, nll)
70
+
71
+ return -nll + kl * beta, kl, nll
72
+
73
+ def LT_fitted_gauss_2std(self, x,num_var=6, num_traversal=5):
74
+ # Cycle linearly through +-2 std dev of a fitted Gaussian.
75
+ x = x.view(x.shape[0], -1)
76
+ mu, lv = self.BottomUp(x)
77
+
78
+ images = []
79
+ for i, batch_mu in enumerate(mu[:num_var]):
80
+ images.append(torch.sigmoid(self.TopDown(batch_mu)).unsqueeze(0))
81
+ for latent_var in range(batch_mu.shape[0]):
82
+ new_mu = batch_mu.unsqueeze(0).repeat([num_traversal, 1])
83
+ loc = mu[:, latent_var].mean()
84
+ total_var = lv[:, latent_var].exp().mean() + mu[:, latent_var].var()
85
+ scale = total_var.sqrt()
86
+ new_mu[:, latent_var] = cycle_interval(batch_mu[latent_var], num_traversal,
87
+ loc - 2 * scale, loc + 2 * scale)
88
+ images.append(torch.sigmoid(self.TopDown(new_mu)))
89
+ return images
90
+
91
+
92
+ if __name__ == "__main__":
93
+ model = BetaVAE_Linear()
94
+ x = torch.rand(10, 784)
95
+ out = model(x)
96
+ print(out.shape)
97
+ loss, kl, nll = model.calc_loss(x, 0.05)
98
+ print(loss, kl, nll)
99
+ images = model.LT_fitted_gauss_2std(x)
100
+ print(len(images), images[0].shape)
101
+ print(images[0].shape)
modules/commons.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from munch import Munch
7
+ import json
8
+
9
+ class AttrDict(dict):
10
+ def __init__(self, *args, **kwargs):
11
+ super(AttrDict, self).__init__(*args, **kwargs)
12
+ self.__dict__ = self
13
+
14
+ def init_weights(m, mean=0.0, std=0.01):
15
+ classname = m.__class__.__name__
16
+ if classname.find("Conv") != -1:
17
+ m.weight.data.normal_(mean, std)
18
+
19
+
20
+ def get_padding(kernel_size, dilation=1):
21
+ return int((kernel_size*dilation - dilation)/2)
22
+
23
+
24
+ def convert_pad_shape(pad_shape):
25
+ l = pad_shape[::-1]
26
+ pad_shape = [item for sublist in l for item in sublist]
27
+ return pad_shape
28
+
29
+
30
+ def intersperse(lst, item):
31
+ result = [item] * (len(lst) * 2 + 1)
32
+ result[1::2] = lst
33
+ return result
34
+
35
+
36
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
37
+ """KL(P||Q)"""
38
+ kl = (logs_q - logs_p) - 0.5
39
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
40
+ return kl
41
+
42
+
43
+ def rand_gumbel(shape):
44
+ """Sample from the Gumbel distribution, protect from overflows."""
45
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
46
+ return -torch.log(-torch.log(uniform_samples))
47
+
48
+
49
+ def rand_gumbel_like(x):
50
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
51
+ return g
52
+
53
+
54
+ def slice_segments(x, ids_str, segment_size=4):
55
+ ret = torch.zeros_like(x[:, :, :segment_size])
56
+ for i in range(x.size(0)):
57
+ idx_str = ids_str[i]
58
+ idx_end = idx_str + segment_size
59
+ ret[i] = x[i, :, idx_str:idx_end]
60
+ return ret
61
+
62
+ def slice_segments_audio(x, ids_str, segment_size=4):
63
+ ret = torch.zeros_like(x[:, :segment_size])
64
+ for i in range(x.size(0)):
65
+ idx_str = ids_str[i]
66
+ idx_end = idx_str + segment_size
67
+ ret[i] = x[i, idx_str:idx_end]
68
+ return ret
69
+
70
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
71
+ b, d, t = x.size()
72
+ if x_lengths is None:
73
+ x_lengths = t
74
+ ids_str_max = x_lengths - segment_size + 1
75
+ ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(dtype=torch.long)
76
+ ret = slice_segments(x, ids_str, segment_size)
77
+ return ret, ids_str
78
+
79
+
80
+ def get_timing_signal_1d(
81
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
82
+ position = torch.arange(length, dtype=torch.float)
83
+ num_timescales = channels // 2
84
+ log_timescale_increment = (
85
+ math.log(float(max_timescale) / float(min_timescale)) /
86
+ (num_timescales - 1))
87
+ inv_timescales = min_timescale * torch.exp(
88
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
89
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
90
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
91
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
92
+ signal = signal.view(1, channels, length)
93
+ return signal
94
+
95
+
96
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
97
+ b, channels, length = x.size()
98
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
99
+ return x + signal.to(dtype=x.dtype, device=x.device)
100
+
101
+
102
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
103
+ b, channels, length = x.size()
104
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
105
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
106
+
107
+
108
+ def subsequent_mask(length):
109
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
110
+ return mask
111
+
112
+
113
+ @torch.jit.script
114
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
115
+ n_channels_int = n_channels[0]
116
+ in_act = input_a + input_b
117
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
118
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
119
+ acts = t_act * s_act
120
+ return acts
121
+
122
+
123
+ def convert_pad_shape(pad_shape):
124
+ l = pad_shape[::-1]
125
+ pad_shape = [item for sublist in l for item in sublist]
126
+ return pad_shape
127
+
128
+
129
+ def shift_1d(x):
130
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
131
+ return x
132
+
133
+
134
+ def sequence_mask(length, max_length=None):
135
+ if max_length is None:
136
+ max_length = length.max()
137
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
138
+ return x.unsqueeze(0) < length.unsqueeze(1)
139
+
140
+
141
+ def generate_path(duration, mask):
142
+ """
143
+ duration: [b, 1, t_x]
144
+ mask: [b, 1, t_y, t_x]
145
+ """
146
+ device = duration.device
147
+
148
+ b, _, t_y, t_x = mask.shape
149
+ cum_duration = torch.cumsum(duration, -1)
150
+
151
+ cum_duration_flat = cum_duration.view(b * t_x)
152
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
153
+ path = path.view(b, t_x, t_y)
154
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
155
+ path = path.unsqueeze(1).transpose(2,3) * mask
156
+ return path
157
+
158
+
159
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
160
+ if isinstance(parameters, torch.Tensor):
161
+ parameters = [parameters]
162
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
163
+ norm_type = float(norm_type)
164
+ if clip_value is not None:
165
+ clip_value = float(clip_value)
166
+
167
+ total_norm = 0
168
+ for p in parameters:
169
+ param_norm = p.grad.data.norm(norm_type)
170
+ total_norm += param_norm.item() ** norm_type
171
+ if clip_value is not None:
172
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
173
+ total_norm = total_norm ** (1. / norm_type)
174
+ return total_norm
175
+
176
+ def log_norm(x, mean=-4, std=4, dim=2):
177
+ """
178
+ normalized log mel -> mel -> norm -> log(norm)
179
+ """
180
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
181
+ return x
182
+
183
+ def load_F0_models(path):
184
+ # load F0 model
185
+ from .JDC.model import JDCNet
186
+ F0_model = JDCNet(num_class=1, seq_len=192)
187
+ params = torch.load(path, map_location='cpu')['net']
188
+ F0_model.load_state_dict(params)
189
+ _ = F0_model.train()
190
+
191
+ return F0_model
192
+
193
+ def modify_w2v_forward(self, output_layer=15):
194
+ '''
195
+ change forward method of w2v encoder to get its intermediate layer output
196
+ :param self:
197
+ :param layer:
198
+ :return:
199
+ '''
200
+ from transformers.modeling_outputs import BaseModelOutput
201
+ def forward(
202
+ hidden_states,
203
+ attention_mask=None,
204
+ output_attentions=False,
205
+ output_hidden_states=False,
206
+ return_dict=True,
207
+ ):
208
+ all_hidden_states = () if output_hidden_states else None
209
+ all_self_attentions = () if output_attentions else None
210
+
211
+ conv_attention_mask = attention_mask
212
+ if attention_mask is not None:
213
+ # make sure padded tokens output 0
214
+ hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0)
215
+
216
+ # extend attention_mask
217
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
218
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
219
+ attention_mask = attention_mask.expand(
220
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
221
+ )
222
+
223
+ hidden_states = self.dropout(hidden_states)
224
+
225
+ if self.embed_positions is not None:
226
+ relative_position_embeddings = self.embed_positions(hidden_states)
227
+ else:
228
+ relative_position_embeddings = None
229
+
230
+ deepspeed_zero3_is_enabled = False
231
+
232
+ for i, layer in enumerate(self.layers):
233
+ if output_hidden_states:
234
+ all_hidden_states = all_hidden_states + (hidden_states,)
235
+
236
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
237
+ dropout_probability = torch.rand([])
238
+
239
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
240
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
241
+ # under deepspeed zero3 all gpus must run in sync
242
+ if self.gradient_checkpointing and self.training:
243
+ layer_outputs = self._gradient_checkpointing_func(
244
+ layer.__call__,
245
+ hidden_states,
246
+ attention_mask,
247
+ relative_position_embeddings,
248
+ output_attentions,
249
+ conv_attention_mask,
250
+ )
251
+ else:
252
+ layer_outputs = layer(
253
+ hidden_states,
254
+ attention_mask=attention_mask,
255
+ relative_position_embeddings=relative_position_embeddings,
256
+ output_attentions=output_attentions,
257
+ conv_attention_mask=conv_attention_mask,
258
+ )
259
+ hidden_states = layer_outputs[0]
260
+
261
+ if skip_the_layer:
262
+ layer_outputs = (None, None)
263
+
264
+ if output_attentions:
265
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
266
+
267
+ if i == output_layer - 1:
268
+ break
269
+
270
+ if output_hidden_states:
271
+ all_hidden_states = all_hidden_states + (hidden_states,)
272
+
273
+ if not return_dict:
274
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
275
+ return BaseModelOutput(
276
+ last_hidden_state=hidden_states,
277
+ hidden_states=all_hidden_states,
278
+ attentions=all_self_attentions,
279
+ )
280
+ return forward
281
+
282
+
283
+ def build_model(args, stage='codec'):
284
+ if stage == 'codec':
285
+ # Generators
286
+ from dac.model.dac import Encoder, Decoder
287
+ from modules.quantize import FAquantizer, FApredictors, CNNLSTM, GradientReversal
288
+
289
+ # Discriminators
290
+ from dac.model.discriminator import Discriminator
291
+
292
+ encoder = Encoder(d_model=args.DAC.encoder_dim,
293
+ strides=args.DAC.encoder_rates,
294
+ d_latent=1024,
295
+ causal=args.causal,
296
+ lstm=args.lstm,)
297
+
298
+ quantizer = FAquantizer(in_dim=1024,
299
+ n_p_codebooks=1,
300
+ n_c_codebooks=args.n_c_codebooks,
301
+ n_t_codebooks=2,
302
+ n_r_codebooks=3,
303
+ codebook_size=1024,
304
+ codebook_dim=8,
305
+ quantizer_dropout=0.5,
306
+ causal=args.causal,
307
+ separate_prosody_encoder=args.separate_prosody_encoder,
308
+ timbre_norm=args.timbre_norm,
309
+ )
310
+
311
+ fa_predictors = FApredictors(in_dim=1024,
312
+ use_gr_content_f0=args.use_gr_content_f0,
313
+ use_gr_prosody_phone=args.use_gr_prosody_phone,
314
+ use_gr_residual_f0=True,
315
+ use_gr_residual_phone=True,
316
+ use_gr_timbre_content=True,
317
+ use_gr_timbre_prosody=args.use_gr_timbre_prosody,
318
+ use_gr_x_timbre=True,
319
+ norm_f0=args.norm_f0,
320
+ timbre_norm=args.timbre_norm,
321
+ use_gr_content_global_f0=args.use_gr_content_global_f0,
322
+ )
323
+
324
+
325
+
326
+ decoder = Decoder(
327
+ input_channel=1024,
328
+ channels=args.DAC.decoder_dim,
329
+ rates=args.DAC.decoder_rates,
330
+ causal=args.causal,
331
+ lstm=args.lstm,
332
+ )
333
+
334
+ discriminator = Discriminator(
335
+ rates=[],
336
+ periods=[2, 3, 5, 7, 11],
337
+ fft_sizes=[2048, 1024, 512],
338
+ sample_rate=args.DAC.sr,
339
+ bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
340
+ )
341
+
342
+ nets = Munch(
343
+ encoder=encoder,
344
+ quantizer=quantizer,
345
+ decoder=decoder,
346
+ discriminator=discriminator,
347
+ fa_predictors=fa_predictors,
348
+ )
349
+ elif stage == 'beta_vae':
350
+ from dac.model.dac import Encoder, Decoder
351
+ from modules.beta_vae import BetaVAE_Linear
352
+ # Discriminators
353
+ from dac.model.discriminator import Discriminator
354
+
355
+ encoder = Encoder(d_model=args.DAC.encoder_dim,
356
+ strides=args.DAC.encoder_rates,
357
+ d_latent=1024,
358
+ causal=args.causal,
359
+ lstm=args.lstm, )
360
+
361
+ decoder = Decoder(
362
+ input_channel=1024,
363
+ channels=args.DAC.decoder_dim,
364
+ rates=args.DAC.decoder_rates,
365
+ causal=args.causal,
366
+ lstm=args.lstm,
367
+ )
368
+
369
+ discriminator = Discriminator(
370
+ rates=[],
371
+ periods=[2, 3, 5, 7, 11],
372
+ fft_sizes=[2048, 1024, 512],
373
+ sample_rate=args.DAC.sr,
374
+ bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
375
+ )
376
+
377
+ beta_vae = BetaVAE_Linear(in_dim=1024, n_hidden=64, latent=8)
378
+
379
+ nets = Munch(
380
+ encoder=encoder,
381
+ decoder=decoder,
382
+ discriminator=discriminator,
383
+ beta_vae=beta_vae,
384
+ )
385
+ elif stage == 'redecoder':
386
+ # from vc.models import FastTransformer, SlowTransformer, Mambo
387
+ from dac.model.dac import Encoder, Decoder
388
+ from dac.model.discriminator import Discriminator
389
+ from modules.redecoder import Redecoder
390
+
391
+ encoder = Redecoder(args)
392
+
393
+ decoder = Decoder(
394
+ input_channel=1024,
395
+ channels=args.DAC.decoder_dim,
396
+ rates=args.DAC.decoder_rates,
397
+ causal=args.decoder_causal,
398
+ lstm=args.decoder_lstm,
399
+ )
400
+
401
+ discriminator = Discriminator(
402
+ rates=[],
403
+ periods=[2, 3, 5, 7, 11],
404
+ fft_sizes=[2048, 1024, 512],
405
+ sample_rate=args.DAC.sr,
406
+ bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
407
+ )
408
+
409
+ nets = Munch(
410
+ encoder=encoder,
411
+ decoder=decoder,
412
+ discriminator=discriminator,
413
+ )
414
+ elif stage == 'encoder':
415
+ from dac.model.dac import Encoder, Decoder
416
+ from modules.quantize import FAquantizer
417
+
418
+ encoder = Encoder(d_model=args.DAC.encoder_dim,
419
+ strides=args.DAC.encoder_rates,
420
+ d_latent=1024,
421
+ causal=args.encoder_causal,
422
+ lstm=args.encoder_lstm,)
423
+
424
+ quantizer = FAquantizer(in_dim=1024,
425
+ n_p_codebooks=1,
426
+ n_c_codebooks=args.n_c_codebooks,
427
+ n_t_codebooks=2,
428
+ n_r_codebooks=3,
429
+ codebook_size=1024,
430
+ codebook_dim=8,
431
+ quantizer_dropout=0.5,
432
+ causal=args.encoder_causal,
433
+ separate_prosody_encoder=args.separate_prosody_encoder,
434
+ timbre_norm=args.timbre_norm,
435
+ )
436
+ nets = Munch(
437
+ encoder=encoder,
438
+ quantizer=quantizer,
439
+ )
440
+ else:
441
+ raise ValueError(f"Unknown stage: {stage}")
442
+
443
+ return nets
444
+
445
+
446
+ def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[], is_distributed=False):
447
+ state = torch.load(path, map_location='cpu')
448
+ params = state['net']
449
+ for key in model:
450
+ if key in params and key not in ignore_modules:
451
+ if not is_distributed:
452
+ # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
453
+ for k in list(params[key].keys()):
454
+ if k.startswith('module.'):
455
+ params[key][k[len("module."):]] = params[key][k]
456
+ del params[key][k]
457
+ print('%s loaded' % key)
458
+ model[key].load_state_dict(params[key], strict=True)
459
+ _ = [model[key].eval() for key in model]
460
+
461
+ if not load_only_params:
462
+ epoch = state["epoch"] + 1
463
+ iters = state["iters"]
464
+ optimizer.load_state_dict(state["optimizer"])
465
+ optimizer.load_scheduler_state_dict(state["scheduler"])
466
+
467
+ else:
468
+ epoch = state["epoch"] + 1
469
+ iters = state["iters"]
470
+
471
+ return model, optimizer, epoch, iters
472
+
473
+ def recursive_munch(d):
474
+ if isinstance(d, dict):
475
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
476
+ elif isinstance(d, list):
477
+ return [recursive_munch(v) for v in d]
478
+ else:
479
+ return d
modules/layers.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cummulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, M, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).tranpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
modules/quantize.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dac.nn.quantize import ResidualVectorQuantize
2
+ from torch import nn
3
+ from modules.wavenet import WN
4
+ from modules.style_encoder import StyleEncoder
5
+ from gradient_reversal import GradientReversal
6
+ import torch
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+ import numpy as np
10
+ from alias_free_torch import *
11
+ from torch.nn.utils import weight_norm
12
+ from torch import nn, sin, pow
13
+ from einops.layers.torch import Rearrange
14
+ from dac.model.encodec import SConv1d
15
+
16
+ def init_weights(m):
17
+ if isinstance(m, nn.Conv1d):
18
+ nn.init.trunc_normal_(m.weight, std=0.02)
19
+ nn.init.constant_(m.bias, 0)
20
+
21
+
22
+ def WNConv1d(*args, **kwargs):
23
+ return weight_norm(nn.Conv1d(*args, **kwargs))
24
+
25
+
26
+ def WNConvTranspose1d(*args, **kwargs):
27
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
28
+
29
+ class SnakeBeta(nn.Module):
30
+ """
31
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
32
+ Shape:
33
+ - Input: (B, C, T)
34
+ - Output: (B, C, T), same shape as the input
35
+ Parameters:
36
+ - alpha - trainable parameter that controls frequency
37
+ - beta - trainable parameter that controls magnitude
38
+ References:
39
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
40
+ https://arxiv.org/abs/2006.08195
41
+ Examples:
42
+ >>> a1 = snakebeta(256)
43
+ >>> x = torch.randn(256)
44
+ >>> x = a1(x)
45
+ """
46
+
47
+ def __init__(
48
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
49
+ ):
50
+ """
51
+ Initialization.
52
+ INPUT:
53
+ - in_features: shape of the input
54
+ - alpha - trainable parameter that controls frequency
55
+ - beta - trainable parameter that controls magnitude
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ beta is initialized to 1 by default, higher values = higher-magnitude.
58
+ alpha will be trained along with the rest of your model.
59
+ """
60
+ super(SnakeBeta, self).__init__()
61
+ self.in_features = in_features
62
+
63
+ # initialize alpha
64
+ self.alpha_logscale = alpha_logscale
65
+ if self.alpha_logscale: # log scale alphas initialized to zeros
66
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
67
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
68
+ else: # linear scale alphas initialized to ones
69
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
70
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
71
+
72
+ self.alpha.requires_grad = alpha_trainable
73
+ self.beta.requires_grad = alpha_trainable
74
+
75
+ self.no_div_by_zero = 0.000000001
76
+
77
+ def forward(self, x):
78
+ """
79
+ Forward pass of the function.
80
+ Applies the function to the input elementwise.
81
+ SnakeBeta := x + 1/b * sin^2 (xa)
82
+ """
83
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
84
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
85
+ if self.alpha_logscale:
86
+ alpha = torch.exp(alpha)
87
+ beta = torch.exp(beta)
88
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
89
+
90
+ return x
91
+
92
+ class ResidualUnit(nn.Module):
93
+ def __init__(self, dim: int = 16, dilation: int = 1):
94
+ super().__init__()
95
+ pad = ((7 - 1) * dilation) // 2
96
+ self.block = nn.Sequential(
97
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
98
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
99
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
100
+ WNConv1d(dim, dim, kernel_size=1),
101
+ )
102
+
103
+ def forward(self, x):
104
+ return x + self.block(x)
105
+
106
+ class CNNLSTM(nn.Module):
107
+ def __init__(self, indim, outdim, head, global_pred=False):
108
+ super().__init__()
109
+ self.global_pred = global_pred
110
+ self.model = nn.Sequential(
111
+ ResidualUnit(indim, dilation=1),
112
+ ResidualUnit(indim, dilation=2),
113
+ ResidualUnit(indim, dilation=3),
114
+ Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
115
+ Rearrange("b c t -> b t c"),
116
+ )
117
+ self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
118
+
119
+ def forward(self, x):
120
+ # x: [B, C, T]
121
+ x = self.model(x)
122
+ if self.global_pred:
123
+ x = torch.mean(x, dim=1, keepdim=False)
124
+ outs = [head(x) for head in self.heads]
125
+ return outs
126
+
127
+ def sequence_mask(length, max_length=None):
128
+ if max_length is None:
129
+ max_length = length.max()
130
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
131
+ return x.unsqueeze(0) < length.unsqueeze(1)
132
+
133
+ class MFCC(nn.Module):
134
+ def __init__(self, n_mfcc=40, n_mels=80):
135
+ super(MFCC, self).__init__()
136
+ self.n_mfcc = n_mfcc
137
+ self.n_mels = n_mels
138
+ self.norm = 'ortho'
139
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
140
+ self.register_buffer('dct_mat', dct_mat)
141
+
142
+ def forward(self, mel_specgram):
143
+ if len(mel_specgram.shape) == 2:
144
+ mel_specgram = mel_specgram.unsqueeze(0)
145
+ unsqueezed = True
146
+ else:
147
+ unsqueezed = False
148
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
149
+ # -> (channel, time, n_mfcc).tranpose(...)
150
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
151
+
152
+ # unpack batch
153
+ if unsqueezed:
154
+ mfcc = mfcc.squeeze(0)
155
+ return mfcc
156
+ class FAquantizer(nn.Module):
157
+ def __init__(self, in_dim=1024,
158
+ n_p_codebooks=1,
159
+ n_c_codebooks=2,
160
+ n_t_codebooks=2,
161
+ n_r_codebooks=3,
162
+ codebook_size=1024,
163
+ codebook_dim=8,
164
+ quantizer_dropout=0.5,
165
+ causal=False,
166
+ separate_prosody_encoder=False,
167
+ timbre_norm=False,):
168
+ super(FAquantizer, self).__init__()
169
+ conv1d_type = SConv1d# if causal else nn.Conv1d
170
+ self.prosody_quantizer = ResidualVectorQuantize(
171
+ input_dim=in_dim,
172
+ n_codebooks=n_p_codebooks,
173
+ codebook_size=codebook_size,
174
+ codebook_dim=codebook_dim,
175
+ quantizer_dropout=quantizer_dropout,
176
+ )
177
+
178
+ self.content_quantizer = ResidualVectorQuantize(
179
+ input_dim=in_dim,
180
+ n_codebooks=n_c_codebooks,
181
+ codebook_size=codebook_size,
182
+ codebook_dim=codebook_dim,
183
+ quantizer_dropout=quantizer_dropout,
184
+ )
185
+
186
+ if not timbre_norm:
187
+ self.timbre_quantizer = ResidualVectorQuantize(
188
+ input_dim=in_dim,
189
+ n_codebooks=n_t_codebooks,
190
+ codebook_size=codebook_size,
191
+ codebook_dim=codebook_dim,
192
+ quantizer_dropout=quantizer_dropout,
193
+ )
194
+ else:
195
+ self.timbre_encoder = StyleEncoder(in_dim=80, hidden_dim=512, out_dim=in_dim)
196
+ self.timbre_linear = nn.Linear(1024, 1024 * 2)
197
+ self.timbre_linear.bias.data[:1024] = 1
198
+ self.timbre_linear.bias.data[1024:] = 0
199
+ self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False)
200
+
201
+ self.residual_quantizer = ResidualVectorQuantize(
202
+ input_dim=in_dim,
203
+ n_codebooks=n_r_codebooks,
204
+ codebook_size=codebook_size,
205
+ codebook_dim=codebook_dim,
206
+ quantizer_dropout=quantizer_dropout,
207
+ )
208
+
209
+ if separate_prosody_encoder:
210
+ self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal)
211
+ self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal)
212
+ self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal)
213
+ else:
214
+ pass
215
+ self.separate_prosody_encoder = separate_prosody_encoder
216
+
217
+ self.prob_random_mask_residual = 0.75
218
+
219
+ SPECT_PARAMS = {
220
+ "n_fft": 2048,
221
+ "win_length": 1200,
222
+ "hop_length": 300,
223
+ }
224
+ MEL_PARAMS = {
225
+ "n_mels": 80,
226
+ }
227
+
228
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
229
+ n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
230
+ )
231
+ self.mel_mean, self.mel_std = -4, 4
232
+ self.frame_rate = 24000 / 300
233
+ self.hop_length = 300
234
+
235
+ self.is_timbre_norm = timbre_norm
236
+ if timbre_norm:
237
+ self.forward = self.forward_v2
238
+
239
+ def preprocess(self, wave_tensor, n_bins=20):
240
+ mel_tensor = self.to_mel(wave_tensor.squeeze(1))
241
+ mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
242
+ return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)]
243
+
244
+ @torch.no_grad()
245
+ def decode(self, codes):
246
+ code_c, code_p, code_t = codes.split([1, 1, 2], dim=1)
247
+
248
+ z_c = self.content_quantizer.from_codes(code_c)[0]
249
+ z_p = self.prosody_quantizer.from_codes(code_p)[0]
250
+ z_t = self.timbre_quantizer.from_codes(code_t)[0]
251
+
252
+ z = z_c + z_p + z_t
253
+
254
+ return z, [z_c, z_p, z_t]
255
+
256
+
257
+ @torch.no_grad()
258
+ def encode(self, x, wave_segments, n_c=1):
259
+ outs = 0
260
+ if self.separate_prosody_encoder:
261
+ prosody_feature = self.preprocess(wave_segments)
262
+
263
+ f0_input = prosody_feature # (B, T, 20)
264
+ f0_input = self.melspec_linear(f0_input)
265
+ f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(
266
+ f0_input.device).bool())
267
+ f0_input = self.melspec_linear2(f0_input)
268
+
269
+ common_min_size = min(f0_input.size(2), x.size(2))
270
+ f0_input = f0_input[:, :, :common_min_size]
271
+
272
+ x = x[:, :, :common_min_size]
273
+
274
+ z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
275
+ f0_input, 1
276
+ )
277
+ outs += z_p.detach()
278
+ else:
279
+ z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
280
+ x, 1
281
+ )
282
+ outs += z_p.detach()
283
+
284
+ z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
285
+ x, n_c
286
+ )
287
+ outs += z_c.detach()
288
+
289
+ timbre_residual_feature = x - z_p.detach() - z_c.detach()
290
+
291
+ z_t, codes_t, latents_t, commitment_loss_t, codebook_loss_t = self.timbre_quantizer(
292
+ timbre_residual_feature, 2
293
+ )
294
+ outs += z_t # we should not detach timbre
295
+
296
+ residual_feature = timbre_residual_feature - z_t
297
+
298
+ z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
299
+ residual_feature, 3
300
+ )
301
+
302
+ return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r]
303
+ def forward(self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2):
304
+ # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
305
+ # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device))
306
+ outs = 0
307
+ if self.separate_prosody_encoder:
308
+ prosody_feature = self.preprocess(wave_segments)
309
+
310
+ f0_input = prosody_feature # (B, T, 20)
311
+ f0_input = self.melspec_linear(f0_input)
312
+ f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(f0_input.device).bool())
313
+ f0_input = self.melspec_linear2(f0_input)
314
+
315
+ common_min_size = min(f0_input.size(2), x.size(2))
316
+ f0_input = f0_input[:, :, :common_min_size]
317
+
318
+ x = x[:, :, :common_min_size]
319
+
320
+ z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
321
+ f0_input, 1
322
+ )
323
+ outs += z_p.detach()
324
+ else:
325
+ z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
326
+ x, 1
327
+ )
328
+ outs += z_p.detach()
329
+
330
+ z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
331
+ x, n_c
332
+ )
333
+ outs += z_c.detach()
334
+
335
+ timbre_residual_feature = x - z_p.detach() - z_c.detach()
336
+
337
+ z_t, codes_t, latents_t, commitment_loss_t, codebook_loss_t = self.timbre_quantizer(
338
+ timbre_residual_feature, n_t
339
+ )
340
+ outs += z_t # we should not detach timbre
341
+
342
+ residual_feature = timbre_residual_feature - z_t
343
+
344
+ z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
345
+ residual_feature, 3
346
+ )
347
+
348
+ bsz = z_r.shape[0]
349
+ res_mask = np.random.choice(
350
+ [0, 1],
351
+ size=bsz,
352
+ p=[
353
+ self.prob_random_mask_residual,
354
+ 1 - self.prob_random_mask_residual,
355
+ ],
356
+ )
357
+ res_mask = (
358
+ torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
359
+ ) # (B, 1, 1)
360
+ res_mask = res_mask.to(
361
+ device=z_r.device, dtype=z_r.dtype
362
+ )
363
+ noise_must_on = noise_added_flags * recon_noisy_flags
364
+ noise_must_off = noise_added_flags * (~recon_noisy_flags)
365
+ res_mask[noise_must_on] = 1
366
+ res_mask[noise_must_off] = 0
367
+
368
+ outs += z_r * res_mask
369
+
370
+ quantized = [z_p, z_c, z_t, z_r]
371
+ commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_t + commitment_loss_r
372
+ codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r
373
+
374
+ return outs, quantized, commitment_losses, codebook_losses
375
+ def forward_v2(self, x, wave_segments, n_c=1, n_t=2, full_waves=None, wave_lens=None, return_codes=False):
376
+ # timbre = self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1))
377
+ if full_waves is None:
378
+ mel = self.preprocess(wave_segments, n_bins=80)
379
+ timbre = self.timbre_encoder(mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device))
380
+ else:
381
+ mel = self.preprocess(full_waves, n_bins=80)
382
+ timbre = self.timbre_encoder(mel, sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1))
383
+ outs = 0
384
+ if self.separate_prosody_encoder:
385
+ prosody_feature = self.preprocess(wave_segments)
386
+
387
+ f0_input = prosody_feature # (B, T, 20)
388
+ f0_input = self.melspec_linear(f0_input)
389
+ f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(
390
+ f0_input.device).bool())
391
+ f0_input = self.melspec_linear2(f0_input)
392
+
393
+ common_min_size = min(f0_input.size(2), x.size(2))
394
+ f0_input = f0_input[:, :, :common_min_size]
395
+
396
+ x = x[:, :, :common_min_size]
397
+
398
+ z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
399
+ f0_input, 1
400
+ )
401
+ outs += z_p.detach()
402
+ else:
403
+ z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
404
+ x, 1
405
+ )
406
+ outs += z_p.detach()
407
+
408
+ z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
409
+ x, n_c
410
+ )
411
+ outs += z_c.detach()
412
+
413
+ residual_feature = x - z_p.detach() - z_c.detach()
414
+
415
+ z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
416
+ residual_feature, 3
417
+ )
418
+
419
+ bsz = z_r.shape[0]
420
+ res_mask = np.random.choice(
421
+ [0, 1],
422
+ size=bsz,
423
+ p=[
424
+ self.prob_random_mask_residual,
425
+ 1 - self.prob_random_mask_residual,
426
+ ],
427
+ )
428
+ res_mask = (
429
+ torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
430
+ ) # (B, 1, 1)
431
+ res_mask = res_mask.to(
432
+ device=z_r.device, dtype=z_r.dtype
433
+ )
434
+
435
+ if not self.training:
436
+ res_mask = torch.ones_like(res_mask)
437
+ outs += z_r * res_mask
438
+
439
+ quantized = [z_p, z_c, z_r]
440
+ codes = [codes_p, codes_c, codes_r]
441
+ commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r
442
+ codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r
443
+
444
+ style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1)
445
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
446
+ outs = outs.transpose(1, 2)
447
+ outs = self.timbre_norm(outs)
448
+ outs = outs.transpose(1, 2)
449
+ outs = outs * gamma + beta
450
+
451
+ if return_codes:
452
+ return outs, quantized, commitment_losses, codebook_losses, timbre, codes
453
+ else:
454
+ return outs, quantized, commitment_losses, codebook_losses, timbre
455
+
456
+ class FApredictors(nn.Module):
457
+ def __init__(self,
458
+ in_dim=1024,
459
+ use_gr_content_f0=False,
460
+ use_gr_prosody_phone=False,
461
+ use_gr_residual_f0=False,
462
+ use_gr_residual_phone=False,
463
+ use_gr_timbre_content=True,
464
+ use_gr_timbre_prosody=True,
465
+ use_gr_x_timbre=False,
466
+ norm_f0=True,
467
+ timbre_norm=False,
468
+ use_gr_content_global_f0=False,
469
+ ):
470
+ super(FApredictors, self).__init__()
471
+ self.f0_predictor = CNNLSTM(in_dim, 1, 2)
472
+ self.phone_predictor = CNNLSTM(in_dim, 1024, 1)
473
+ if timbre_norm:
474
+ self.timbre_predictor = nn.Linear(in_dim, 20000)
475
+ else:
476
+ self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True)
477
+
478
+ self.use_gr_content_f0 = use_gr_content_f0
479
+ self.use_gr_prosody_phone = use_gr_prosody_phone
480
+ self.use_gr_residual_f0 = use_gr_residual_f0
481
+ self.use_gr_residual_phone = use_gr_residual_phone
482
+ self.use_gr_timbre_content = use_gr_timbre_content
483
+ self.use_gr_timbre_prosody = use_gr_timbre_prosody
484
+ self.use_gr_x_timbre = use_gr_x_timbre
485
+
486
+ self.rev_f0_predictor = nn.Sequential(
487
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 2)
488
+ )
489
+ self.rev_content_predictor = nn.Sequential(
490
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1)
491
+ )
492
+ self.rev_timbre_predictor = nn.Sequential(
493
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True)
494
+ )
495
+
496
+ self.norm_f0 = norm_f0
497
+ self.timbre_norm = timbre_norm
498
+ if timbre_norm:
499
+ self.forward = self.forward_v2
500
+ self.global_f0_predictor = nn.Linear(in_dim, 1)
501
+
502
+ self.use_gr_content_global_f0 = use_gr_content_global_f0
503
+ if use_gr_content_global_f0:
504
+ self.rev_global_f0_predictor = nn.Sequential(
505
+ GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True)
506
+ )
507
+ def forward(self, quantized):
508
+ prosody_latent = quantized[0]
509
+ content_latent = quantized[1]
510
+ timbre_latent = quantized[2]
511
+ residual_latent = quantized[3]
512
+ content_pred = self.phone_predictor(content_latent)[0]
513
+
514
+ if self.norm_f0:
515
+ spk_pred = self.timbre_predictor(timbre_latent)[0]
516
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent)
517
+ else:
518
+ spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0]
519
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent)
520
+
521
+ prosody_rev_latent = torch.zeros_like(quantized[0])
522
+ if self.use_gr_content_f0:
523
+ prosody_rev_latent += quantized[1]
524
+ if self.use_gr_timbre_prosody:
525
+ prosody_rev_latent += quantized[2]
526
+ if self.use_gr_residual_f0:
527
+ prosody_rev_latent += quantized[3]
528
+ rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
529
+
530
+ content_rev_latent = torch.zeros_like(quantized[1])
531
+ if self.use_gr_prosody_phone:
532
+ content_rev_latent += quantized[0]
533
+ if self.use_gr_timbre_content:
534
+ content_rev_latent += quantized[2]
535
+ if self.use_gr_residual_phone:
536
+ content_rev_latent += quantized[3]
537
+ rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
538
+
539
+ if self.norm_f0:
540
+ timbre_rev_latent = quantized[0] + quantized[1] + quantized[3]
541
+ else:
542
+ timbre_rev_latent = quantized[1] + quantized[3]
543
+ if self.use_gr_x_timbre:
544
+ x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
545
+ else:
546
+ x_spk_pred = None
547
+
548
+
549
+
550
+ preds = {
551
+ 'f0': f0_pred,
552
+ 'uv': uv_pred,
553
+ 'content': content_pred,
554
+ 'timbre': spk_pred,
555
+ }
556
+
557
+ rev_preds = {
558
+ 'rev_f0': rev_f0_pred,
559
+ 'rev_uv': rev_uv_pred,
560
+ 'rev_content': rev_content_pred,
561
+ 'x_timbre': x_spk_pred,
562
+ }
563
+ return preds, rev_preds
564
+ def forward_v2(self, quantized, timbre):
565
+ assert self.use_gr_content_global_f0
566
+ prosody_latent = quantized[0]
567
+ content_latent = quantized[1]
568
+ residual_latent = quantized[2]
569
+ content_pred = self.phone_predictor(content_latent)[0]
570
+
571
+ # spk_pred = self.timbre_predictor(timbre)[0]
572
+ f0_pred, uv_pred = self.f0_predictor(prosody_latent)
573
+
574
+ prosody_rev_latent = torch.zeros_like(prosody_latent)
575
+ if self.use_gr_content_f0:
576
+ prosody_rev_latent += content_latent
577
+ if self.use_gr_residual_f0:
578
+ prosody_rev_latent += residual_latent
579
+ rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent)
580
+
581
+ content_rev_latent = torch.zeros_like(content_latent)
582
+ if self.use_gr_prosody_phone:
583
+ content_rev_latent += prosody_latent
584
+ if self.use_gr_residual_phone:
585
+ content_rev_latent += residual_latent
586
+ rev_content_pred = self.rev_content_predictor(content_rev_latent)[0]
587
+
588
+ timbre_rev_latent = prosody_latent + content_latent + residual_latent
589
+ if self.use_gr_x_timbre:
590
+ x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0]
591
+ else:
592
+ x_spk_pred = None
593
+
594
+ global_f0_pred = self.global_f0_predictor(timbre)
595
+ if self.use_gr_content_global_f0:
596
+ rev_global_f0_pred = self.rev_global_f0_predictor(content_latent + prosody_latent + residual_latent)[0]
597
+
598
+ preds = {
599
+ 'f0': f0_pred,
600
+ 'uv': uv_pred,
601
+ 'content': content_pred,
602
+ 'timbre': None,
603
+ 'global_f0': global_f0_pred,
604
+ }
605
+
606
+ rev_preds = {
607
+ 'rev_f0': rev_f0_pred,
608
+ 'rev_uv': rev_uv_pred,
609
+ 'rev_content': rev_content_pred,
610
+ 'x_timbre': x_spk_pred,
611
+ 'rev_global_f0': rev_global_f0_pred,
612
+ }
613
+ return preds, rev_preds
modules/redecoder.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from modules.wavenet import WN
3
+ #
4
+ class Redecoder(torch.nn.Module):
5
+ def __init__(self, args):
6
+ super(Redecoder, self).__init__()
7
+ self.n_p_codebooks = args.n_p_codebooks # number of prosody codebooks
8
+ self.n_c_codebooks = args.n_c_codebooks # number of content codebooks
9
+ self.codebook_size = 1024 # codebook size
10
+ self.encoder_type = args.encoder_type
11
+ if args.encoder_type == "wavenet":
12
+ self.embed_dim = args.wavenet_embed_dim
13
+ self.encoder = WN(hidden_channels=self.embed_dim, kernel_size=5, dilation_rate=1, n_layers=16, gin_channels=1024
14
+ , p_dropout=0.2, causal=args.decoder_causal)
15
+ self.conv_out = torch.nn.Conv1d(self.embed_dim, 1024, 1)
16
+ self.prosody_embed = torch.nn.ModuleList(
17
+ [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_p_codebooks)])
18
+ self.content_embed = torch.nn.ModuleList(
19
+ [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_c_codebooks)])
20
+ elif args.encoder_type == "mamba":
21
+ from modules.mamba import Mambo
22
+ self.embed_dim = args.mamba_embed_dim
23
+ self.encoder = Mambo(d_model=self.embed_dim, n_layer=24, vocab_size=1024,
24
+ prob_random_mask_prosody=args.prob_random_mask_prosody,
25
+ prob_random_mask_content=args.prob_random_mask_content,)
26
+ self.conv_out = torch.nn.Linear(self.embed_dim, 1024)
27
+ self.forward = self.forward_v2
28
+ self.prosody_embed = torch.nn.ModuleList(
29
+ [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_p_codebooks)])
30
+ self.content_embed = torch.nn.ModuleList(
31
+ [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_c_codebooks)])
32
+ else:
33
+ raise NotImplementedError
34
+
35
+ def forward(self, p_code, c_code, timbre_vec, use_p_code=True, use_c_code=True, n_c=2):
36
+ B, _, T = p_code.size()
37
+ p_embed = torch.zeros(B, T, self.embed_dim).to(p_code.device)
38
+ c_embed = torch.zeros(B, T, self.embed_dim).to(c_code.device)
39
+ if use_p_code:
40
+ for i in range(self.n_p_codebooks):
41
+ p_embed += self.prosody_embed[i](p_code[:, i, :])
42
+ if use_c_code:
43
+ for i in range(n_c):
44
+ c_embed += self.content_embed[i](c_code[:, i, :])
45
+ x = p_embed + c_embed
46
+ x = self.encoder(x.transpose(1, 2), x_mask=torch.ones(B, 1, T).to(p_code.device), g=timbre_vec.unsqueeze(2))
47
+ x = self.conv_out(x)
48
+ return x
49
+ def forward_v2(self, p_code, c_code, timbre_vec, use_p_code=True, use_c_code=True, n_c=2):
50
+ x = self.encoder(torch.cat([p_code, c_code], dim=1), timbre_vec)
51
+ x = self.conv_out(x).transpose(1, 2)
52
+ return x
53
+ @torch.no_grad()
54
+ def generate(self, prompt_ids, input_ids, prompt_context, timbre, use_p_code=True, use_c_code=True, n_c=2):
55
+ from modules.mamba import InferenceParams
56
+ assert self.encoder_type == "mamba"
57
+ inference_params = InferenceParams(max_seqlen=8192, max_batch_size=1)
58
+ # run once with prompt to initialize memory first
59
+ prompt_out = self.encoder(prompt_ids, prompt_context, timbre, inference_params=inference_params)
60
+ for i in range(input_ids.size(-1)):
61
+ input_id = input_ids[..., i]
62
+ prompt_out = self.encoder(input_id, prompt_out, timbre, inference_params=inference_params)
63
+
modules/style_encoder.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import attentions
2
+ from torch import nn
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+ class Mish(nn.Module):
7
+ def __init__(self):
8
+ super(Mish, self).__init__()
9
+ def forward(self, x):
10
+ return x * torch.tanh(F.softplus(x))
11
+
12
+
13
+ class Conv1dGLU(nn.Module):
14
+ '''
15
+ Conv1d + GLU(Gated Linear Unit) with residual connection.
16
+ For GLU refer to https://arxiv.org/abs/1612.08083 paper.
17
+ '''
18
+
19
+ def __init__(self, in_channels, out_channels, kernel_size, dropout):
20
+ super(Conv1dGLU, self).__init__()
21
+ self.out_channels = out_channels
22
+ self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2)
23
+ self.dropout = nn.Dropout(dropout)
24
+
25
+ def forward(self, x):
26
+ residual = x
27
+ x = self.conv1(x)
28
+ x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
29
+ x = x1 * torch.sigmoid(x2)
30
+ x = residual + self.dropout(x)
31
+ return x
32
+
33
+ class StyleEncoder(torch.nn.Module):
34
+ def __init__(self, in_dim=513, hidden_dim=128, out_dim=256):
35
+
36
+ super().__init__()
37
+
38
+ self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024
39
+ self.hidden_dim = hidden_dim
40
+ self.out_dim = out_dim
41
+ self.kernel_size = 5
42
+ self.n_head = 2
43
+ self.dropout = 0.1
44
+
45
+ self.spectral = nn.Sequential(
46
+ nn.Conv1d(self.in_dim, self.hidden_dim, 1),
47
+ Mish(),
48
+ nn.Dropout(self.dropout),
49
+ nn.Conv1d(self.hidden_dim, self.hidden_dim, 1),
50
+ Mish(),
51
+ nn.Dropout(self.dropout)
52
+ )
53
+
54
+ self.temporal = nn.Sequential(
55
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
56
+ Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
57
+ )
58
+
59
+ self.slf_attn = attentions.MultiHeadAttention(self.hidden_dim, self.hidden_dim, self.n_head, p_dropout = self.dropout, proximal_bias= False, proximal_init=True)
60
+ self.atten_drop = nn.Dropout(self.dropout)
61
+ self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1)
62
+
63
+ def forward(self, x, mask=None):
64
+
65
+ # spectral
66
+ x = self.spectral(x)*mask
67
+ # temporal
68
+ x = self.temporal(x)*mask
69
+
70
+ # self-attention
71
+ attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1)
72
+ y = self.slf_attn(x,x, attn_mask=attn_mask)
73
+ x = x+ self.atten_drop(y)
74
+
75
+ # fc
76
+ x = self.fc(x)
77
+
78
+ # temoral average pooling
79
+ w = self.temporal_avg_pool(x, mask=mask)
80
+
81
+ return w
82
+
83
+ def temporal_avg_pool(self, x, mask=None):
84
+ if mask is None:
85
+ out = torch.mean(x, dim=2)
86
+ else:
87
+ len_ = mask.sum(dim=2)
88
+ x = x.sum(dim=2)
89
+
90
+ out = torch.div(x, len_)
91
+ return out
modules/wavenet.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from dac.model.encodec import SConv1d
7
+
8
+ from . import commons
9
+ LRELU_SLOPE = 0.1
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels, eps=1e-5):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x):
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ class ConvReluNorm(nn.Module):
27
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
28
+ super().__init__()
29
+ self.in_channels = in_channels
30
+ self.hidden_channels = hidden_channels
31
+ self.out_channels = out_channels
32
+ self.kernel_size = kernel_size
33
+ self.n_layers = n_layers
34
+ self.p_dropout = p_dropout
35
+ assert n_layers > 1, "Number of layers should be larger than 0."
36
+
37
+ self.conv_layers = nn.ModuleList()
38
+ self.norm_layers = nn.ModuleList()
39
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
40
+ self.norm_layers.append(LayerNorm(hidden_channels))
41
+ self.relu_drop = nn.Sequential(
42
+ nn.ReLU(),
43
+ nn.Dropout(p_dropout))
44
+ for _ in range(n_layers - 1):
45
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
46
+ self.norm_layers.append(LayerNorm(hidden_channels))
47
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
48
+ self.proj.weight.data.zero_()
49
+ self.proj.bias.data.zero_()
50
+
51
+ def forward(self, x, x_mask):
52
+ x_org = x
53
+ for i in range(self.n_layers):
54
+ x = self.conv_layers[i](x * x_mask)
55
+ x = self.norm_layers[i](x)
56
+ x = self.relu_drop(x)
57
+ x = x_org + self.proj(x)
58
+ return x * x_mask
59
+
60
+
61
+ class DDSConv(nn.Module):
62
+ """
63
+ Dialted and Depth-Separable Convolution
64
+ """
65
+
66
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
67
+ super().__init__()
68
+ self.channels = channels
69
+ self.kernel_size = kernel_size
70
+ self.n_layers = n_layers
71
+ self.p_dropout = p_dropout
72
+
73
+ self.drop = nn.Dropout(p_dropout)
74
+ self.convs_sep = nn.ModuleList()
75
+ self.convs_1x1 = nn.ModuleList()
76
+ self.norms_1 = nn.ModuleList()
77
+ self.norms_2 = nn.ModuleList()
78
+ for i in range(n_layers):
79
+ dilation = kernel_size ** i
80
+ padding = (kernel_size * dilation - dilation) // 2
81
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
82
+ groups=channels, dilation=dilation, padding=padding
83
+ ))
84
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
85
+ self.norms_1.append(LayerNorm(channels))
86
+ self.norms_2.append(LayerNorm(channels))
87
+
88
+ def forward(self, x, x_mask, g=None):
89
+ if g is not None:
90
+ x = x + g
91
+ for i in range(self.n_layers):
92
+ y = self.convs_sep[i](x * x_mask)
93
+ y = self.norms_1[i](y)
94
+ y = F.gelu(y)
95
+ y = self.convs_1x1[i](y)
96
+ y = self.norms_2[i](y)
97
+ y = F.gelu(y)
98
+ y = self.drop(y)
99
+ x = x + y
100
+ return x * x_mask
101
+
102
+
103
+ class WN(torch.nn.Module):
104
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False):
105
+ super(WN, self).__init__()
106
+ conv1d_type = SConv1d
107
+ assert (kernel_size % 2 == 1)
108
+ self.hidden_channels = hidden_channels
109
+ self.kernel_size = kernel_size,
110
+ self.dilation_rate = dilation_rate
111
+ self.n_layers = n_layers
112
+ self.gin_channels = gin_channels
113
+ self.p_dropout = p_dropout
114
+
115
+ self.in_layers = torch.nn.ModuleList()
116
+ self.res_skip_layers = torch.nn.ModuleList()
117
+ self.drop = nn.Dropout(p_dropout)
118
+
119
+ if gin_channels != 0:
120
+ self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm')
121
+
122
+ for i in range(n_layers):
123
+ dilation = dilation_rate ** i
124
+ padding = int((kernel_size * dilation - dilation) / 2)
125
+ in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation,
126
+ padding=padding, norm='weight_norm', causal=causal)
127
+ self.in_layers.append(in_layer)
128
+
129
+ # last one is not necessary
130
+ if i < n_layers - 1:
131
+ res_skip_channels = 2 * hidden_channels
132
+ else:
133
+ res_skip_channels = hidden_channels
134
+
135
+ res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal)
136
+ self.res_skip_layers.append(res_skip_layer)
137
+
138
+ def forward(self, x, x_mask, g=None, **kwargs):
139
+ output = torch.zeros_like(x)
140
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
141
+
142
+ if g is not None:
143
+ g = self.cond_layer(g)
144
+
145
+ for i in range(self.n_layers):
146
+ x_in = self.in_layers[i](x)
147
+ if g is not None:
148
+ cond_offset = i * 2 * self.hidden_channels
149
+ g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
150
+ else:
151
+ g_l = torch.zeros_like(x_in)
152
+
153
+ acts = commons.fused_add_tanh_sigmoid_multiply(
154
+ x_in,
155
+ g_l,
156
+ n_channels_tensor)
157
+ acts = self.drop(acts)
158
+
159
+ res_skip_acts = self.res_skip_layers[i](acts)
160
+ if i < self.n_layers - 1:
161
+ res_acts = res_skip_acts[:, :self.hidden_channels, :]
162
+ x = (x + res_acts) * x_mask
163
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
164
+ else:
165
+ output = output + res_skip_acts
166
+ return output * x_mask
167
+
168
+ def remove_weight_norm(self):
169
+ if self.gin_channels != 0:
170
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
171
+ for l in self.in_layers:
172
+ torch.nn.utils.remove_weight_norm(l)
173
+ for l in self.res_skip_layers:
174
+ torch.nn.utils.remove_weight_norm(l)