diff --git a/alias_free_torch/__init__.py b/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..080beecd700d74c9c915b7c41c5ab13c79103902 --- /dev/null +++ b/alias_free_torch/__init__.py @@ -0,0 +1,5 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +from .filter import * +from .resample import * +from .act import * diff --git a/alias_free_torch/__pycache__/__init__.cpython-310.pyc b/alias_free_torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c75b5e594a44efbca854dcafda8e2c180ba51ed5 Binary files /dev/null and b/alias_free_torch/__pycache__/__init__.cpython-310.pyc differ diff --git a/alias_free_torch/__pycache__/act.cpython-310.pyc b/alias_free_torch/__pycache__/act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b22efc16b9b2e13222df9a254da6678688eb6d4 Binary files /dev/null and b/alias_free_torch/__pycache__/act.cpython-310.pyc differ diff --git a/alias_free_torch/__pycache__/filter.cpython-310.pyc b/alias_free_torch/__pycache__/filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41bd841e2b9f779466a0b38464f7283956b225ee Binary files /dev/null and b/alias_free_torch/__pycache__/filter.cpython-310.pyc differ diff --git a/alias_free_torch/__pycache__/resample.cpython-310.pyc b/alias_free_torch/__pycache__/resample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..400811d54bc1f94795475507d40f879ad6b24353 Binary files /dev/null and b/alias_free_torch/__pycache__/resample.cpython-310.pyc differ diff --git a/alias_free_torch/act.py b/alias_free_torch/act.py new file mode 100644 index 0000000000000000000000000000000000000000..adc7eb7469c6fcb25e200abb78425f977d655cd8 --- /dev/null +++ b/alias_free_torch/act.py @@ -0,0 +1,29 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x diff --git a/alias_free_torch/filter.py b/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..8e18bda585d0006c4f34f5ae04bdea854c4e37b7 --- /dev/null +++ b/alias_free_torch/filter.py @@ -0,0 +1,96 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 
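+
+        For example, ``sinc(torch.zeros(1))`` evaluates to ``tensor([1.])``, and
+        ``sinc(torch.tensor([0.5]))`` gives ``sin(pi / 2) / (pi / 2)``, roughly ``0.6366``.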
+ """ + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +def kaiser_sinc_filter1d( + cutoff, half_width, kernel_size +): # return filter [1,1,kernel_size] + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = "replicate", + kernel_size: int = 12, + ): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + # input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + return out diff --git a/alias_free_torch/resample.py b/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..dd35bc32b4b30a293ded20f4a66c1578ffbfc54e --- /dev/null +++ b/alias_free_torch/resample.py @@ -0,0 +1,57 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, 
(self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/dac/__init__.py b/dac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5d03388fa63486960c783ebe7f1bd411b95b1d --- /dev/null +++ b/dac/__init__.py @@ -0,0 +1,16 @@ +__version__ = "1.0.0" + +# preserved here for legacy reasons +__model_version__ = "latest" + +import audiotools + +audiotools.ml.BaseModel.INTERN += ["dac.**"] +audiotools.ml.BaseModel.EXTERN += ["einops"] + + +from . import nn +from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/dac/__main__.py b/dac/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..393698e7da671bce1478b4d78b2f685d2640636b --- /dev/null +++ b/dac/__main__.py @@ -0,0 +1,36 @@ +import sys + +import argbind + +from dac.utils import download +from dac.utils.decode import decode +from dac.utils.encode import encode + +STAGES = ["encode", "decode", "download"] + + +def run(stage: str): + """Run stages. + + Parameters + ---------- + stage : str + Stage to run + """ + if stage not in STAGES: + raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}") + stage_fn = globals()[stage] + + if stage == "download": + stage_fn() + return + + stage_fn() + + +if __name__ == "__main__": + group = sys.argv.pop(1) + args = argbind.parse_args(group=group) + + with argbind.scope(args): + run(group) diff --git a/dac/__pycache__/__init__.cpython-310.pyc b/dac/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8703ba1470dbfa18097c051ca8a8bf779a7a47f8 Binary files /dev/null and b/dac/__pycache__/__init__.cpython-310.pyc differ diff --git a/dac/model/__init__.py b/dac/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58b47475d39e4249d2cd45577503bb68ffdacd00 --- /dev/null +++ b/dac/model/__init__.py @@ -0,0 +1,4 @@ +from .base import CodecMixin +from .base import DACFile +from .dac import DAC +from .discriminator import Discriminator diff --git a/dac/model/__pycache__/__init__.cpython-310.pyc b/dac/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3625a289baf823d704a88cfb64c0c6c2cfc53c5e Binary files /dev/null and b/dac/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/dac/model/__pycache__/base.cpython-310.pyc b/dac/model/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e7180b494c011cb420809ce3beceaf517715114 Binary files /dev/null and b/dac/model/__pycache__/base.cpython-310.pyc differ diff --git a/dac/model/__pycache__/dac.cpython-310.pyc b/dac/model/__pycache__/dac.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3ebcfc6617bcb6e279f03678dd3373c76024cbb Binary files /dev/null and b/dac/model/__pycache__/dac.cpython-310.pyc differ diff --git 
a/dac/model/__pycache__/discriminator.cpython-310.pyc b/dac/model/__pycache__/discriminator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6bc1254baf887f24b9351050b1bf4150182c5ca Binary files /dev/null and b/dac/model/__pycache__/discriminator.cpython-310.pyc differ diff --git a/dac/model/__pycache__/encodec.cpython-310.pyc b/dac/model/__pycache__/encodec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d29068c69504ccd6a65cc9b1e5f05e1bfa7f62b Binary files /dev/null and b/dac/model/__pycache__/encodec.cpython-310.pyc differ diff --git a/dac/model/base.py b/dac/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef5a44074ca2bd5d726fe90fa7c8c87da1b3b7a --- /dev/null +++ b/dac/model/base.py @@ -0,0 +1,294 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." 
+ ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. 
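+
+        A minimal usage sketch (assumes ``model`` is a ``DAC`` instance whose
+        weights have already been loaded; the file name is illustrative):
+
+            dac_file = model.compress("speech.wav", win_duration=5.0)
+            dac_file.save("speech.dac")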
+ + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. 
+ verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) + + self.padding = original_padding + return recons diff --git a/dac/model/dac.py b/dac/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..b235ef98454ab7dc0870b555a48fa061d0d9d1e1 --- /dev/null +++ b/dac/model/dac.py @@ -0,0 +1,389 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from dac.nn.layers import Snake1d +from dac.nn.layers import WNConv1d +from dac.nn.layers import WNConvTranspose1d +from dac.nn.quantize import ResidualVectorQuantize +from .encodec import SConv1d, SConvTranspose1d, SLSTM + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + conv1d_type(dim, dim, kernel_size=7, dilation=dilation, padding=pad, causal=causal, norm='weight_norm'), + Snake1d(dim), + conv1d_type(dim, dim, kernel_size=1, causal=causal, norm='weight_norm'), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1, causal=causal), + ResidualUnit(dim // 2, dilation=3, causal=causal), + ResidualUnit(dim // 2, dilation=9, causal=causal), + Snake1d(dim // 2), + conv1d_type( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + causal=causal, + norm='weight_norm', + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + causal: bool = False, + lstm: int = 2, + ): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + # Create first convolution + self.block 
= [conv1d_type(1, d_model, kernel_size=7, padding=3, causal=causal, norm='weight_norm')] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride, causal=causal)] + + # Add LSTM if needed + self.use_lstm = lstm + if lstm: + self.block += [SLSTM(d_model, lstm)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + conv1d_type(d_model, d_latent, kernel_size=3, padding=1, causal=causal, norm='weight_norm'), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, causal: bool = False): + super().__init__() + conv1d_type = SConvTranspose1d #if causal else WNConvTranspose1d + self.block = nn.Sequential( + Snake1d(input_dim), + conv1d_type( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + causal=causal, + norm='weight_norm' + ), + ResidualUnit(output_dim, dilation=1, causal=causal), + ResidualUnit(output_dim, dilation=3, causal=causal), + ResidualUnit(output_dim, dilation=9, causal=causal), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + causal: bool = False, + lstm: int = 2, + ): + super().__init__() + conv1d_type = SConv1d# if causal else WNConv1d + # Add first conv layer + layers = [conv1d_type(input_channel, channels, kernel_size=7, padding=3, causal=causal, norm='weight_norm')] + + if lstm: + layers += [SLSTM(channels, num_layers=lstm)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + conv1d_type(output_dim, d_out, kernel_size=7, padding=3, causal=causal, norm='weight_norm'), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + lstm: int = 2, + causal: bool = False, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + lstm=lstm, + causal=causal, + ) + 
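+        # Store the target sample rate, initialize conv weights
+        # (truncated-normal weights, zero biases), and pre-compute the codec's
+        # total delay from the assembled conv / transposed-conv stack.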
self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + z, n_quantizers + ) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. 
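+
+        Example (illustrative sketch; shapes assume the default configuration,
+        whose hop length is prod([2, 4, 8, 8]) = 512):
+
+            model = DAC()
+            x = torch.randn(1, 1, 44100)
+            out = model(x, sample_rate=44100)
+            out["audio"].shape   # torch.Size([1, 1, 44100])
+            out["codes"].shape   # torch.Size([1, 9, 87]); 87 = ceil(44100 / 512)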
+ """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers + ) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." + setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/dac/model/discriminator.py b/dac/model/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa1a307e1dcff19ed36c681eac3dba39ddabb59 --- /dev/null +++ b/dac/model/discriminator.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import ml +from audiotools import STFTParams +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 
4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(nn.Module): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. 
+ periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = Discriminator() + x = torch.zeros(1, 1, 44100) + results = disc(x) + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print() diff --git a/dac/model/encodec.py b/dac/model/encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb8cd367370357a6269b86f1e7fe6f749c64efd --- /dev/null +++ b/dac/model/encodec.py @@ -0,0 +1,288 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Convolutional layers wrappers and utilities.""" + +import math +import typing as tp +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +import typing as tp + +import einops + + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, 'b ... t -> b t ...') + x = super().forward(x) + x = einops.rearrange(x, 'b t ... -> b ... t') + return + + +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return weight_norm(module) + elif norm == 'spectral_norm': + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. 
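+
+    For instance, ``norm='time_group_norm'`` wraps the conv's output channels in a
+    single-group ``nn.GroupNorm`` (and raises for causal convolutions), while
+    ``'weight_norm'`` and ``'spectral_norm'`` return ``nn.Identity()`` here, since
+    those are applied as parametrizations in ``apply_parametrization_norm``.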
+ """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`. + """ + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. 
+ """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + def __init__(self, *args, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect', **kwargs): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1' + f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. 
+ """ + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y \ No newline at end of file diff --git a/dac/nn/__init__.py b/dac/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2f6b972aeb8a21764ff73dae2095eb94bb8ba4 --- /dev/null +++ b/dac/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . 
import quantize diff --git a/dac/nn/__pycache__/__init__.cpython-310.pyc b/dac/nn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..722ff5a805bdba09e61bf3877556d2c7a17fc6e8 Binary files /dev/null and b/dac/nn/__pycache__/__init__.cpython-310.pyc differ diff --git a/dac/nn/__pycache__/layers.cpython-310.pyc b/dac/nn/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6634f0fa488939cc118ab224b42dbd92596a9df8 Binary files /dev/null and b/dac/nn/__pycache__/layers.cpython-310.pyc differ diff --git a/dac/nn/__pycache__/loss.cpython-310.pyc b/dac/nn/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93fc8803daba44278eeb5c020785e084f3496796 Binary files /dev/null and b/dac/nn/__pycache__/loss.cpython-310.pyc differ diff --git a/dac/nn/__pycache__/quantize.cpython-310.pyc b/dac/nn/__pycache__/quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..babfe8c95cb39cec26141dfb32040194098f0418 Binary files /dev/null and b/dac/nn/__pycache__/quantize.cpython-310.pyc differ diff --git a/dac/nn/layers.py b/dac/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..63c53a88b881deab94fb06b6a395951cf9cc995a --- /dev/null +++ b/dac/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/dac/nn/loss.py b/dac/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7a26e62f885f0bb7c6774e95660a7848e10a8f87 --- /dev/null +++ b/dac/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. 
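+
+        Example (illustrative sketch; random signals stand in for real audio):
+
+            loss_fn = L1Loss()
+            est = AudioSignal(torch.randn(1, 1, 44100), 44100)
+            ref = AudioSignal(torch.randn(1, 1, 44100), 44100)
+            value = loss_fn(est, ref)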
+ """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. 
Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. 
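+
+    A minimal usage sketch (random tensors stand in for estimate and reference audio):
+
+        mel_loss = MelSpectrogramLoss()
+        est = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        ref = AudioSignal(torch.randn(1, 1, 44100), 44100)
+        value = mel_loss(est, ref)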
+ + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. 
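+
+    A minimal usage sketch (``disc`` is assumed to be a ``Discriminator`` instance
+    and ``fake``/``real`` are AudioSignals):
+
+        gan_loss = GANLoss(discriminator=disc)
+        loss_d = gan_loss.discriminator_loss(fake, real)
+        loss_g, loss_feat = gan_loss.generator_loss(fake, real)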
+ """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/dac/nn/quantize.py b/dac/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..333014ab814d47df1778ae34c1ba0ff55e90f480 --- /dev/null +++ b/dac/nn/quantize.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from dac.nn.layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = 
rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, 
codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/dac/utils/__init__.py b/dac/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6dbede7b3e57655b4a732c92c84d7c72fbc9059 --- /dev/null +++ b/dac/utils/__init__.py @@ -0,0 +1,123 @@ +from pathlib import Path + +import argbind +from audiotools import ml + +import dac + +DAC = dac.model.DAC +Accelerator = ml.Accelerator + +__MODEL_LATEST_TAGS__ = { + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", + ( + "16khz", + "0.0.5", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", +} + + +@argbind.bind(group="download", positional=True, without_prefix=True) +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): + """ + Function that downloads the weights file from URL if a local cache is not found. + + Parameters + ---------- + model_type : str + The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". + + Returns + ------- + Path + Directory path required to load model via audiotools. 
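    Example (an illustrative sketch; the first call needs network access, after
    which the weights are served from ~/.cache/descript/dac):

    >>> from dac.utils import download
    >>> weights_path = download(model_type="44khz", model_bitrate="8kbps", tag="latest")
    >>> weights_path.name
    'weights_44khz_8kbps_0.0.1.pth'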
+ """ + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + "16khz", + ], "model_type must be one of '44khz', '24khz', or '16khz'" + + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] + + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" + ) + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Download the model + import requests + + response = requests.get(download_link) + + if response.status_code != 200: + raise ValueError( + f"Could not download model. Received response code {response.status_code}" + ) + local_path.write_bytes(response.content) + + return local_path + + +def load_model( + model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, +): + if not load_path: + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) + return generator diff --git a/dac/utils/__pycache__/__init__.cpython-310.pyc b/dac/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98a97e209c35162f4a42627e1241733738c07a72 Binary files /dev/null and b/dac/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/dac/utils/decode.py b/dac/utils/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..48c25298fd0f9a53f19c1d6d260f1f0563c7b7e4 --- /dev/null +++ b/dac/utils/decode.py @@ -0,0 +1,95 @@ +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from tqdm import tqdm + +from dac import DACFile +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="decode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def decode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + device: str = "cuda", + model_type: str = "44khz", + verbose: bool = False, +): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
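    Example (an illustrative sketch; the paths are hypothetical, and the weights
    are fetched automatically when no `weights_path` is given):

    >>> decode("codes/", output="reconstructed/", model_type="44khz", device="cpu")

    The `__main__` block below exposes the same function as a command-line tool
    through argbind.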
+ """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + + # Find all .dac files in input directory + _input = Path(input) + input_files = list(_input.glob("**/*.dac")) + + # If input is a .dac file, add it to the list + if _input.suffix == ".dac": + input_files.append(_input) + + # Create output directory + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(input_files)), desc=f"Decoding files"): + # Load file + artifact = DACFile.load(input_files[i]) + + # Reconstruct audio from codes + recons = generator.decompress(artifact, verbose=verbose) + + # Compute output path + relative_path = input_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = input_files[i] + output_name = relative_path.with_suffix(".wav").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to file + recons.write(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + decode() diff --git a/dac/utils/encode.py b/dac/utils/encode.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9a8b582b24c194ba82c2d280c4df4fa2cfff87 --- /dev/null +++ b/dac/utils/encode.py @@ -0,0 +1,94 @@ +import math +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.core import util +from tqdm import tqdm + +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="encode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def encode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + n_quantizers: int = None, + device: str = "cuda", + model_type: str = "44khz", + win_duration: float = 5.0, + verbose: bool = False, +): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. 
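    Example (an illustrative sketch; the input directory is hypothetical, and a
    smaller `n_quantizers` lowers the bitrate as described above):

    >>> encode("audio/", output="codes/", model_type="44khz", n_quantizers=4, device="cpu")

    As with `decode`, the `__main__` block below exposes this function as a
    command-line tool through argbind.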
+ """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + kwargs = {"n_quantizers": n_quantizers} + + # Find all audio files in input path + input = Path(input) + audio_files = util.find_audio(input) + + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(audio_files)), desc="Encoding files"): + # Load file + signal = AudioSignal(audio_files[i]) + + # Encode audio to .dac format + artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) + + # Compute output path + relative_path = audio_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = audio_files[i] + output_name = relative_path.with_suffix(".dac").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + artifact.save(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + encode() diff --git a/hf_utils.py b/hf_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0eada6aa788c93e5deb818993da905ee65e50e3e --- /dev/null +++ b/hf_utils.py @@ -0,0 +1,11 @@ +import torch +import os +from huggingface_hub import hf_hub_download + + +def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"): + os.makedirs("./checkpoints", exist_ok=True) + model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints") + config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints") + + return model_path, config_path \ No newline at end of file diff --git a/modules/__pycache__/attentions.cpython-310.pyc b/modules/__pycache__/attentions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08462b1d8d27d78331b0de89634f3ff4bfa31e5c Binary files /dev/null and b/modules/__pycache__/attentions.cpython-310.pyc differ diff --git a/modules/__pycache__/commons.cpython-310.pyc b/modules/__pycache__/commons.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abfffd060a0868fa5d96b8b244b008d2b20aeaeb Binary files /dev/null and b/modules/__pycache__/commons.cpython-310.pyc differ diff --git a/modules/__pycache__/layers.cpython-310.pyc b/modules/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d21317d26a64b60e41520dc2dbe195b75006cf31 Binary files /dev/null and b/modules/__pycache__/layers.cpython-310.pyc differ diff --git a/modules/__pycache__/mamba.cpython-310.pyc b/modules/__pycache__/mamba.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a436e7a79999c85aa2df491e2b71a82efd32f5 Binary files /dev/null and b/modules/__pycache__/mamba.cpython-310.pyc differ diff --git a/modules/__pycache__/quantize.cpython-310.pyc b/modules/__pycache__/quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21d057c490d4ccbb5775140cd500cd1fb83782d8 Binary files /dev/null and b/modules/__pycache__/quantize.cpython-310.pyc differ diff --git a/modules/__pycache__/redecoder.cpython-310.pyc b/modules/__pycache__/redecoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7159138e8fba457745949be732d49744f34532e Binary files /dev/null and 
b/modules/__pycache__/redecoder.cpython-310.pyc differ diff --git a/modules/__pycache__/style_encoder.cpython-310.pyc b/modules/__pycache__/style_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e0881c9a03449773ab36ce705fd4806bfa20d57 Binary files /dev/null and b/modules/__pycache__/style_encoder.cpython-310.pyc differ diff --git a/modules/__pycache__/wavenet.cpython-310.pyc b/modules/__pycache__/wavenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..674cf145a686d37675fee439e2fa7a06fb481b60 Binary files /dev/null and b/modules/__pycache__/wavenet.cpython-310.pyc differ diff --git a/modules/attentions.py b/modules/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..de838ae921f5b47bb70182fcee03ca8fd8d3a4ea --- /dev/null +++ b/modules/attentions.py @@ -0,0 +1,324 @@ +import copy +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from . import commons +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, + **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, + window_size=window_size)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., + proximal_bias=False, proximal_init=True, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in 
range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, + proximal_bias=proximal_bias, proximal_init=proximal_init)) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, + block_length=None, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels ** -0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." 
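            # Relative-position attention (as in Shaw et al., 2018): look up the learned
            # relative key embeddings for this sequence length, score the queries against
            # them, then convert the resulting [t, 2*t-1] relative logits into the
            # absolute [t, t] layout before adding them to the content-based scores.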
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) + + # Reshape and slice out the padded elements. 
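        # After the two pads the flattened tensor holds (l + 1) * (2*l - 1) elements,
        # so it can be viewed as [l + 1, 2*l - 1]; keeping rows [:l] and columns
        # [l - 1:] leaves the absolute [l, l] score matrix.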
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) + x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, + causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, commons.convert_pad_shape(padding)) + return x diff --git a/modules/beta_vae.py b/modules/beta_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..f839cc176d9ec15001f8b97e6a8ec2196635517f --- /dev/null +++ b/modules/beta_vae.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch.distributions as td +import numpy as np + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + def forward(self, x): + return x * torch.sigmoid(x) + +def cycle_interval(starting_value, num_frames, min_val, max_val): + """Cycles through the state space in a single cycle.""" + starting_in_01 = ((starting_value - min_val) / (max_val - min_val)).cpu() + grid = torch.linspace(starting_in_01.item(), starting_in_01.item() + 2., steps=num_frames + 1)[:-1] + grid -= np.maximum(0, 2 * grid - 2) + grid += np.maximum(0, -2 * grid) + return grid * (max_val - min_val) + min_val +class BetaVAE_Linear(nn.Module): + def __init__(self, in_dim=1024, n_hidden=64, latent=8): + super(BetaVAE_Linear, self).__init__() + + self.n_hidden = n_hidden + self.latent = latent + + # 
Encoder + self.encoder = nn.Sequential( + nn.Linear(in_dim, n_hidden), Swish(), + ) + + # Latent + self.mu = nn.Linear(n_hidden, latent) + self.lv = nn.Linear(n_hidden, latent) + + # Decoder + self.decoder = nn.Sequential( + nn.Linear(latent, n_hidden), Swish(), + nn.Linear(n_hidden, in_dim), Swish() + ) + + def BottomUp(self, x): + out = self.encoder(x) + mu, lv = self.mu(out), self.lv(out) + return mu, lv + + def reparameterize(self, mu, lv): + std = torch.exp(0.5 * lv) + eps = torch.randn_like(std) + return mu + std * eps + + def TopDown(self, z): + out = self.decoder(z) + return out + + def forward(self, x): + # x = x.view(x.shape[0], -1) + mu, lv = self.BottomUp(x) + z = self.reparameterize(mu, lv) + out = self.TopDown(z) + return out, mu, lv + + def calc_loss(self, x, beta): + mu, lv = self.BottomUp(x) + z = self.reparameterize(mu, lv) + out = torch.sigmoid(self.TopDown(z)) + + nll = -nn.functional.binary_cross_entropy(out, x, reduction='sum') / x.shape[0] + kl = (-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) + 1e-5) / x.shape[0] + # print(kl, nll) + + return -nll + kl * beta, kl, nll + + def LT_fitted_gauss_2std(self, x,num_var=6, num_traversal=5): + # Cycle linearly through +-2 std dev of a fitted Gaussian. + x = x.view(x.shape[0], -1) + mu, lv = self.BottomUp(x) + + images = [] + for i, batch_mu in enumerate(mu[:num_var]): + images.append(torch.sigmoid(self.TopDown(batch_mu)).unsqueeze(0)) + for latent_var in range(batch_mu.shape[0]): + new_mu = batch_mu.unsqueeze(0).repeat([num_traversal, 1]) + loc = mu[:, latent_var].mean() + total_var = lv[:, latent_var].exp().mean() + mu[:, latent_var].var() + scale = total_var.sqrt() + new_mu[:, latent_var] = cycle_interval(batch_mu[latent_var], num_traversal, + loc - 2 * scale, loc + 2 * scale) + images.append(torch.sigmoid(self.TopDown(new_mu))) + return images + + +if __name__ == "__main__": + model = BetaVAE_Linear() + x = torch.rand(10, 784) + out = model(x) + print(out.shape) + loss, kl, nll = model.calc_loss(x, 0.05) + print(loss, kl, nll) + images = model.LT_fitted_gauss_2std(x) + print(len(images), images[0].shape) + print(images[0].shape) \ No newline at end of file diff --git a/modules/commons.py b/modules/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..13666330afc6aa8eea5bc6a487b62b6f2e0ee409 --- /dev/null +++ b/modules/commons.py @@ -0,0 +1,479 @@ +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from munch import Munch +import json + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + +def slice_segments_audio(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (num_timescales - 1)) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path 
= path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2,3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm + +def log_norm(x, mean=-4, std=4, dim=2): + """ + normalized log mel -> mel -> norm -> log(norm) + """ + x = torch.log(torch.exp(x * std + mean).norm(dim=dim)) + return x + +def load_F0_models(path): + # load F0 model + from .JDC.model import JDCNet + F0_model = JDCNet(num_class=1, seq_len=192) + params = torch.load(path, map_location='cpu')['net'] + F0_model.load_state_dict(params) + _ = F0_model.train() + + return F0_model + +def modify_w2v_forward(self, output_layer=15): + ''' + change forward method of w2v encoder to get its intermediate layer output + :param self: + :param layer: + :return: + ''' + from transformers.modeling_outputs import BaseModelOutput + def forward( + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill(~attention_mask.bool().unsqueeze(-1), 0.0) + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + deepspeed_zero3_is_enabled = False + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + relative_position_embeddings, + output_attentions, + conv_attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if i == output_layer - 1: + break + + if 
output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + return forward + + +def build_model(args, stage='codec'): + if stage == 'codec': + # Generators + from dac.model.dac import Encoder, Decoder + from modules.quantize import FAquantizer, FApredictors, CNNLSTM, GradientReversal + + # Discriminators + from dac.model.discriminator import Discriminator + + encoder = Encoder(d_model=args.DAC.encoder_dim, + strides=args.DAC.encoder_rates, + d_latent=1024, + causal=args.causal, + lstm=args.lstm,) + + quantizer = FAquantizer(in_dim=1024, + n_p_codebooks=1, + n_c_codebooks=args.n_c_codebooks, + n_t_codebooks=2, + n_r_codebooks=3, + codebook_size=1024, + codebook_dim=8, + quantizer_dropout=0.5, + causal=args.causal, + separate_prosody_encoder=args.separate_prosody_encoder, + timbre_norm=args.timbre_norm, + ) + + fa_predictors = FApredictors(in_dim=1024, + use_gr_content_f0=args.use_gr_content_f0, + use_gr_prosody_phone=args.use_gr_prosody_phone, + use_gr_residual_f0=True, + use_gr_residual_phone=True, + use_gr_timbre_content=True, + use_gr_timbre_prosody=args.use_gr_timbre_prosody, + use_gr_x_timbre=True, + norm_f0=args.norm_f0, + timbre_norm=args.timbre_norm, + use_gr_content_global_f0=args.use_gr_content_global_f0, + ) + + + + decoder = Decoder( + input_channel=1024, + channels=args.DAC.decoder_dim, + rates=args.DAC.decoder_rates, + causal=args.causal, + lstm=args.lstm, + ) + + discriminator = Discriminator( + rates=[], + periods=[2, 3, 5, 7, 11], + fft_sizes=[2048, 1024, 512], + sample_rate=args.DAC.sr, + bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)], + ) + + nets = Munch( + encoder=encoder, + quantizer=quantizer, + decoder=decoder, + discriminator=discriminator, + fa_predictors=fa_predictors, + ) + elif stage == 'beta_vae': + from dac.model.dac import Encoder, Decoder + from modules.beta_vae import BetaVAE_Linear + # Discriminators + from dac.model.discriminator import Discriminator + + encoder = Encoder(d_model=args.DAC.encoder_dim, + strides=args.DAC.encoder_rates, + d_latent=1024, + causal=args.causal, + lstm=args.lstm, ) + + decoder = Decoder( + input_channel=1024, + channels=args.DAC.decoder_dim, + rates=args.DAC.decoder_rates, + causal=args.causal, + lstm=args.lstm, + ) + + discriminator = Discriminator( + rates=[], + periods=[2, 3, 5, 7, 11], + fft_sizes=[2048, 1024, 512], + sample_rate=args.DAC.sr, + bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)], + ) + + beta_vae = BetaVAE_Linear(in_dim=1024, n_hidden=64, latent=8) + + nets = Munch( + encoder=encoder, + decoder=decoder, + discriminator=discriminator, + beta_vae=beta_vae, + ) + elif stage == 'redecoder': + # from vc.models import FastTransformer, SlowTransformer, Mambo + from dac.model.dac import Encoder, Decoder + from dac.model.discriminator import Discriminator + from modules.redecoder import Redecoder + + encoder = Redecoder(args) + + decoder = Decoder( + input_channel=1024, + channels=args.DAC.decoder_dim, + rates=args.DAC.decoder_rates, + causal=args.decoder_causal, + lstm=args.decoder_lstm, + ) + + discriminator = Discriminator( + rates=[], + periods=[2, 3, 5, 7, 11], + fft_sizes=[2048, 1024, 512], + sample_rate=args.DAC.sr, + bands=[(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)], + 
) + + nets = Munch( + encoder=encoder, + decoder=decoder, + discriminator=discriminator, + ) + elif stage == 'encoder': + from dac.model.dac import Encoder, Decoder + from modules.quantize import FAquantizer + + encoder = Encoder(d_model=args.DAC.encoder_dim, + strides=args.DAC.encoder_rates, + d_latent=1024, + causal=args.encoder_causal, + lstm=args.encoder_lstm,) + + quantizer = FAquantizer(in_dim=1024, + n_p_codebooks=1, + n_c_codebooks=args.n_c_codebooks, + n_t_codebooks=2, + n_r_codebooks=3, + codebook_size=1024, + codebook_dim=8, + quantizer_dropout=0.5, + causal=args.encoder_causal, + separate_prosody_encoder=args.separate_prosody_encoder, + timbre_norm=args.timbre_norm, + ) + nets = Munch( + encoder=encoder, + quantizer=quantizer, + ) + else: + raise ValueError(f"Unknown stage: {stage}") + + return nets + + +def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[], is_distributed=False): + state = torch.load(path, map_location='cpu') + params = state['net'] + for key in model: + if key in params and key not in ignore_modules: + if not is_distributed: + # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix + for k in list(params[key].keys()): + if k.startswith('module.'): + params[key][k[len("module."):]] = params[key][k] + del params[key][k] + print('%s loaded' % key) + model[key].load_state_dict(params[key], strict=True) + _ = [model[key].eval() for key in model] + + if not load_only_params: + epoch = state["epoch"] + 1 + iters = state["iters"] + optimizer.load_state_dict(state["optimizer"]) + optimizer.load_scheduler_state_dict(state["scheduler"]) + + else: + epoch = state["epoch"] + 1 + iters = state["iters"] + + return model, optimizer, epoch, iters + +def recursive_munch(d): + if isinstance(d, dict): + return Munch((k, recursive_munch(v)) for k, v in d.items()) + elif isinstance(d, list): + return [recursive_munch(v) for v in d] + else: + return d \ No newline at end of file diff --git a/modules/layers.py b/modules/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dc218d864e1b3d605220d989cd2f8f1a317ae4c3 --- /dev/null +++ b/modules/layers.py @@ -0,0 +1,354 @@ +import math +import torch +from torch import nn +from typing import Optional, Any +from torch import Tensor +import torch.nn.functional as F +import torchaudio +import torchaudio.functional as audio_F + +import random +random.seed(0) + + +def _get_activation_fn(activ): + if activ == 'relu': + return nn.ReLU() + elif activ == 'lrelu': + return nn.LeakyReLU(0.2) + elif activ == 'swish': + return lambda x: x*torch.sigmoid(x) + else: + raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ) + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): + super(ConvNorm, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, 
dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + +class CausualConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None): + super(CausualConv, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) * 2 + else: + self.padding = padding * 2 + self.conv = nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=self.padding, + dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) + + def forward(self, x): + x = self.conv(x) + x = x[:, :, :-self.padding] + return x + +class CausualBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'): + super(CausualBlock, self).__init__() + self.blocks = nn.ModuleList([ + self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) + for i in range(n_conv)]) + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2): + layers = [ + CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), + _get_activation_fn(activ), + nn.BatchNorm1d(hidden_dim), + nn.Dropout(p=dropout_p), + CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + _get_activation_fn(activ), + nn.Dropout(p=dropout_p) + ] + return nn.Sequential(*layers) + +class ConvBlock(nn.Module): + def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'): + super().__init__() + self._n_groups = 8 + self.blocks = nn.ModuleList([ + self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) + for i in range(n_conv)]) + + + def forward(self, x): + for block in self.blocks: + res = x + x = block(x) + x += res + return x + + def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2): + layers = [ + ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), + _get_activation_fn(activ), + nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), + nn.Dropout(p=dropout_p), + ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), + _get_activation_fn(activ), + nn.Dropout(p=dropout_p) + ] + return nn.Sequential(*layers) + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, + attention_dim): + super(LocationLayer, self).__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = ConvNorm(2, attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, bias=False, stride=1, + dilation=1) + self.location_dense = LinearNorm(attention_n_filters, attention_dim, + bias=False, w_init_gain='tanh') + + def forward(self, attention_weights_cat): + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class Attention(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + attention_location_n_filters, attention_location_kernel_size): + super(Attention, self).__init__() + 
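        # Location-sensitive attention (Tacotron 2 style): project the decoder query
        # and the encoder memory into a shared attention space, extract location
        # features from the previous/cumulative attention weights, and score the sum
        # of the three with the vector `v`.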
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, + bias=False, w_init_gain='tanh') + self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, + w_init_gain='tanh') + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer(attention_location_n_filters, + attention_location_kernel_size, + attention_dim) + self.score_mask_value = -float("inf") + + def get_alignment_energies(self, query, processed_memory, + attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v(torch.tanh( + processed_query + processed_attention_weights + processed_memory)) + + energies = energies.squeeze(-1) + return energies + + def forward(self, attention_hidden_state, memory, processed_memory, + attention_weights_cat, mask): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + alignment = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat) + + if mask is not None: + alignment.data.masked_fill_(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class ForwardAttentionV2(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + attention_location_n_filters, attention_location_kernel_size): + super(ForwardAttentionV2, self).__init__() + self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, + bias=False, w_init_gain='tanh') + self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, + w_init_gain='tanh') + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer(attention_location_n_filters, + attention_location_kernel_size, + attention_dim) + self.score_mask_value = -float(1e20) + + def get_alignment_energies(self, query, processed_memory, + attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: prev. 
and cumulative att weights (B, 2, max_time) + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v(torch.tanh( + processed_query + processed_attention_weights + processed_memory)) + + energies = energies.squeeze(-1) + return energies + + def forward(self, attention_hidden_state, memory, processed_memory, + attention_weights_cat, mask, log_alpha): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + log_energy = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat) + + #log_energy = + + if mask is not None: + log_energy.data.masked_fill_(mask, self.score_mask_value) + + #attention_weights = F.softmax(alignment, dim=1) + + #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] + #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] + + #log_total_score = log_alpha + content_score + + #previous_attention_weights = attention_weights_cat[:,0,:] + + log_alpha_shift_padded = [] + max_time = log_energy.size(1) + for sft in range(2): + shifted = log_alpha[:,:max_time-sft] + shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value) + log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) + + biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2) + + log_alpha_new = biased + log_energy + + attention_weights = F.softmax(log_alpha_new, dim=1) + + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights, log_alpha_new + + +class PhaseShuffle2d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle2d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :, :move] + right = x[:, :, :, move:] + shuffled = torch.cat([right, left], dim=3) + return shuffled + +class PhaseShuffle1d(nn.Module): + def __init__(self, n=2): + super(PhaseShuffle1d, self).__init__() + self.n = n + self.random = random.Random(1) + + def forward(self, x, move=None): + # x.size = (B, C, M, L) + if move is None: + move = self.random.randint(-self.n, self.n) + + if move == 0: + return x + else: + left = x[:, :, :move] + right = x[:, :, move:] + shuffled = torch.cat([right, left], dim=2) + + return shuffled + +class MFCC(nn.Module): + def __init__(self, n_mfcc=40, n_mels=80): + super(MFCC, self).__init__() + self.n_mfcc = n_mfcc + self.n_mels = n_mels + self.norm = 'ortho' + dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) + self.register_buffer('dct_mat', dct_mat) + + def forward(self, mel_specgram): + if len(mel_specgram.shape) == 2: + mel_specgram = mel_specgram.unsqueeze(0) + unsqueezed = True + else: + unsqueezed = False + # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) + # -> (channel, time, n_mfcc).tranpose(...) 
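+ # dct_mat has shape (n_mels, n_mfcc); transposing the input to (channel, time, n_mels)
+ # applies the DCT to every frame, and the final transpose restores (channel, n_mfcc, time).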
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + if unsqueezed: + mfcc = mfcc.squeeze(0) + return mfcc diff --git a/modules/quantize.py b/modules/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..2c99097e25f0e05a2710e33c15587fdf07aa0d36 --- /dev/null +++ b/modules/quantize.py @@ -0,0 +1,613 @@ +from dac.nn.quantize import ResidualVectorQuantize +from torch import nn +from modules.wavenet import WN +from modules.style_encoder import StyleEncoder +from gradient_reversal import GradientReversal +import torch +import torchaudio +import torchaudio.functional as audio_F +import numpy as np +from alias_free_torch import * +from torch.nn.utils import weight_norm +from torch import nn, sin, pow +from einops.layers.torch import Rearrange +from dac.model.encodec import SConv1d + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. 
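+ alpha and beta are per-channel parameters broadcast over batch and time;
+ when alpha_logscale is set they are stored in log space and exponentiated below.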
+ SnakeBeta := x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + return x + self.block(x) + +class CNNLSTM(nn.Module): + def __init__(self, indim, outdim, head, global_pred=False): + super().__init__() + self.global_pred = global_pred + self.model = nn.Sequential( + ResidualUnit(indim, dilation=1), + ResidualUnit(indim, dilation=2), + ResidualUnit(indim, dilation=3), + Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)), + Rearrange("b c t -> b t c"), + ) + self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)]) + + def forward(self, x): + # x: [B, C, T] + x = self.model(x) + if self.global_pred: + x = torch.mean(x, dim=1, keepdim=False) + outs = [head(x) for head in self.heads] + return outs + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + +class MFCC(nn.Module): + def __init__(self, n_mfcc=40, n_mels=80): + super(MFCC, self).__init__() + self.n_mfcc = n_mfcc + self.n_mels = n_mels + self.norm = 'ortho' + dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) + self.register_buffer('dct_mat', dct_mat) + + def forward(self, mel_specgram): + if len(mel_specgram.shape) == 2: + mel_specgram = mel_specgram.unsqueeze(0) + unsqueezed = True + else: + unsqueezed = False + # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) + # -> (channel, time, n_mfcc).tranpose(...) 
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + if unsqueezed: + mfcc = mfcc.squeeze(0) + return mfcc +class FAquantizer(nn.Module): + def __init__(self, in_dim=1024, + n_p_codebooks=1, + n_c_codebooks=2, + n_t_codebooks=2, + n_r_codebooks=3, + codebook_size=1024, + codebook_dim=8, + quantizer_dropout=0.5, + causal=False, + separate_prosody_encoder=False, + timbre_norm=False,): + super(FAquantizer, self).__init__() + conv1d_type = SConv1d# if causal else nn.Conv1d + self.prosody_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_p_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.content_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_c_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + if not timbre_norm: + self.timbre_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_t_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + else: + self.timbre_encoder = StyleEncoder(in_dim=80, hidden_dim=512, out_dim=in_dim) + self.timbre_linear = nn.Linear(1024, 1024 * 2) + self.timbre_linear.bias.data[:1024] = 1 + self.timbre_linear.bias.data[1024:] = 0 + self.timbre_norm = nn.LayerNorm(1024, elementwise_affine=False) + + self.residual_quantizer = ResidualVectorQuantize( + input_dim=in_dim, + n_codebooks=n_r_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + if separate_prosody_encoder: + self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal) + self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal) + self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal) + else: + pass + self.separate_prosody_encoder = separate_prosody_encoder + + self.prob_random_mask_residual = 0.75 + + SPECT_PARAMS = { + "n_fft": 2048, + "win_length": 1200, + "hop_length": 300, + } + MEL_PARAMS = { + "n_mels": 80, + } + + self.to_mel = torchaudio.transforms.MelSpectrogram( + n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS + ) + self.mel_mean, self.mel_std = -4, 4 + self.frame_rate = 24000 / 300 + self.hop_length = 300 + + self.is_timbre_norm = timbre_norm + if timbre_norm: + self.forward = self.forward_v2 + + def preprocess(self, wave_tensor, n_bins=20): + mel_tensor = self.to_mel(wave_tensor.squeeze(1)) + mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std + return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)] + + @torch.no_grad() + def decode(self, codes): + code_c, code_p, code_t = codes.split([1, 1, 2], dim=1) + + z_c = self.content_quantizer.from_codes(code_c)[0] + z_p = self.prosody_quantizer.from_codes(code_p)[0] + z_t = self.timbre_quantizer.from_codes(code_t)[0] + + z = z_c + z_p + z_t + + return z, [z_c, z_p, z_t] + + + @torch.no_grad() + def encode(self, x, wave_segments, n_c=1): + outs = 0 + if self.separate_prosody_encoder: + prosody_feature = self.preprocess(wave_segments) + + f0_input = prosody_feature # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to( + f0_input.device).bool()) + f0_input = 
self.melspec_linear2(f0_input) + + common_min_size = min(f0_input.size(2), x.size(2)) + f0_input = f0_input[:, :, :common_min_size] + + x = x[:, :, :common_min_size] + + z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer( + f0_input, 1 + ) + outs += z_p.detach() + else: + z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer( + x, 1 + ) + outs += z_p.detach() + + z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer( + x, n_c + ) + outs += z_c.detach() + + timbre_residual_feature = x - z_p.detach() - z_c.detach() + + z_t, codes_t, latents_t, commitment_loss_t, codebook_loss_t = self.timbre_quantizer( + timbre_residual_feature, 2 + ) + outs += z_t # we should not detach timbre + + residual_feature = timbre_residual_feature - z_t + + z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer( + residual_feature, 3 + ) + + return [codes_c, codes_p, codes_t, codes_r], [z_c, z_p, z_t, z_r] + def forward(self, x, wave_segments, noise_added_flags, recon_noisy_flags, n_c=2, n_t=2): + # timbre = self.timbre_encoder(mels, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1)) + # timbre = self.timbre_encoder(mel_segments, torch.ones(mel_segments.size(0), 1, mel_segments.size(2)).bool().to(mel_segments.device)) + outs = 0 + if self.separate_prosody_encoder: + prosody_feature = self.preprocess(wave_segments) + + f0_input = prosody_feature # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(f0_input.device).bool()) + f0_input = self.melspec_linear2(f0_input) + + common_min_size = min(f0_input.size(2), x.size(2)) + f0_input = f0_input[:, :, :common_min_size] + + x = x[:, :, :common_min_size] + + z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer( + f0_input, 1 + ) + outs += z_p.detach() + else: + z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer( + x, 1 + ) + outs += z_p.detach() + + z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer( + x, n_c + ) + outs += z_c.detach() + + timbre_residual_feature = x - z_p.detach() - z_c.detach() + + z_t, codes_t, latents_t, commitment_loss_t, codebook_loss_t = self.timbre_quantizer( + timbre_residual_feature, n_t + ) + outs += z_t # we should not detach timbre + + residual_feature = timbre_residual_feature - z_t + + z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer( + residual_feature, 3 + ) + + bsz = z_r.shape[0] + res_mask = np.random.choice( + [0, 1], + size=bsz, + p=[ + self.prob_random_mask_residual, + 1 - self.prob_random_mask_residual, + ], + ) + res_mask = ( + torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) + ) # (B, 1, 1) + res_mask = res_mask.to( + device=z_r.device, dtype=z_r.dtype + ) + noise_must_on = noise_added_flags * recon_noisy_flags + noise_must_off = noise_added_flags * (~recon_noisy_flags) + res_mask[noise_must_on] = 1 + res_mask[noise_must_off] = 0 + + outs += z_r * res_mask + + quantized = [z_p, z_c, z_t, z_r] + commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_t + commitment_loss_r + codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_t + codebook_loss_r + + return outs, quantized, commitment_losses, codebook_losses + def forward_v2(self, x, wave_segments, n_c=1, n_t=2, full_waves=None, wave_lens=None, return_codes=False): + # timbre = 
self.timbre_encoder(x, sequence_mask(mel_lens, mels.size(-1)).unsqueeze(1)) + if full_waves is None: + mel = self.preprocess(wave_segments, n_bins=80) + timbre = self.timbre_encoder(mel, torch.ones(mel.size(0), 1, mel.size(2)).bool().to(mel.device)) + else: + mel = self.preprocess(full_waves, n_bins=80) + timbre = self.timbre_encoder(mel, sequence_mask(wave_lens // self.hop_length, mel.size(-1)).unsqueeze(1)) + outs = 0 + if self.separate_prosody_encoder: + prosody_feature = self.preprocess(wave_segments) + + f0_input = prosody_feature # (B, T, 20) + f0_input = self.melspec_linear(f0_input) + f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to( + f0_input.device).bool()) + f0_input = self.melspec_linear2(f0_input) + + common_min_size = min(f0_input.size(2), x.size(2)) + f0_input = f0_input[:, :, :common_min_size] + + x = x[:, :, :common_min_size] + + z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer( + f0_input, 1 + ) + outs += z_p.detach() + else: + z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer( + x, 1 + ) + outs += z_p.detach() + + z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer( + x, n_c + ) + outs += z_c.detach() + + residual_feature = x - z_p.detach() - z_c.detach() + + z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer( + residual_feature, 3 + ) + + bsz = z_r.shape[0] + res_mask = np.random.choice( + [0, 1], + size=bsz, + p=[ + self.prob_random_mask_residual, + 1 - self.prob_random_mask_residual, + ], + ) + res_mask = ( + torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) + ) # (B, 1, 1) + res_mask = res_mask.to( + device=z_r.device, dtype=z_r.dtype + ) + + if not self.training: + res_mask = torch.ones_like(res_mask) + outs += z_r * res_mask + + quantized = [z_p, z_c, z_r] + codes = [codes_p, codes_c, codes_r] + commitment_losses = commitment_loss_p + commitment_loss_c + commitment_loss_r + codebook_losses = codebook_loss_p + codebook_loss_c + codebook_loss_r + + style = self.timbre_linear(timbre).unsqueeze(2) # (B, 2d, 1) + gamma, beta = style.chunk(2, 1) # (B, d, 1) + outs = outs.transpose(1, 2) + outs = self.timbre_norm(outs) + outs = outs.transpose(1, 2) + outs = outs * gamma + beta + + if return_codes: + return outs, quantized, commitment_losses, codebook_losses, timbre, codes + else: + return outs, quantized, commitment_losses, codebook_losses, timbre + +class FApredictors(nn.Module): + def __init__(self, + in_dim=1024, + use_gr_content_f0=False, + use_gr_prosody_phone=False, + use_gr_residual_f0=False, + use_gr_residual_phone=False, + use_gr_timbre_content=True, + use_gr_timbre_prosody=True, + use_gr_x_timbre=False, + norm_f0=True, + timbre_norm=False, + use_gr_content_global_f0=False, + ): + super(FApredictors, self).__init__() + self.f0_predictor = CNNLSTM(in_dim, 1, 2) + self.phone_predictor = CNNLSTM(in_dim, 1024, 1) + if timbre_norm: + self.timbre_predictor = nn.Linear(in_dim, 20000) + else: + self.timbre_predictor = CNNLSTM(in_dim, 20000, 1, global_pred=True) + + self.use_gr_content_f0 = use_gr_content_f0 + self.use_gr_prosody_phone = use_gr_prosody_phone + self.use_gr_residual_f0 = use_gr_residual_f0 + self.use_gr_residual_phone = use_gr_residual_phone + self.use_gr_timbre_content = use_gr_timbre_content + self.use_gr_timbre_prosody = use_gr_timbre_prosody + self.use_gr_x_timbre = use_gr_x_timbre + + self.rev_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), 
CNNLSTM(in_dim, 1, 2) + ) + self.rev_content_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1024, 1) + ) + self.rev_timbre_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_dim, 20000, 1, global_pred=True) + ) + + self.norm_f0 = norm_f0 + self.timbre_norm = timbre_norm + if timbre_norm: + self.forward = self.forward_v2 + self.global_f0_predictor = nn.Linear(in_dim, 1) + + self.use_gr_content_global_f0 = use_gr_content_global_f0 + if use_gr_content_global_f0: + self.rev_global_f0_predictor = nn.Sequential( + GradientReversal(alpha=1.0), CNNLSTM(in_dim, 1, 1, global_pred=True) + ) + def forward(self, quantized): + prosody_latent = quantized[0] + content_latent = quantized[1] + timbre_latent = quantized[2] + residual_latent = quantized[3] + content_pred = self.phone_predictor(content_latent)[0] + + if self.norm_f0: + spk_pred = self.timbre_predictor(timbre_latent)[0] + f0_pred, uv_pred = self.f0_predictor(prosody_latent) + else: + spk_pred = self.timbre_predictor(timbre_latent + prosody_latent)[0] + f0_pred, uv_pred = self.f0_predictor(prosody_latent + timbre_latent) + + prosody_rev_latent = torch.zeros_like(quantized[0]) + if self.use_gr_content_f0: + prosody_rev_latent += quantized[1] + if self.use_gr_timbre_prosody: + prosody_rev_latent += quantized[2] + if self.use_gr_residual_f0: + prosody_rev_latent += quantized[3] + rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent) + + content_rev_latent = torch.zeros_like(quantized[1]) + if self.use_gr_prosody_phone: + content_rev_latent += quantized[0] + if self.use_gr_timbre_content: + content_rev_latent += quantized[2] + if self.use_gr_residual_phone: + content_rev_latent += quantized[3] + rev_content_pred = self.rev_content_predictor(content_rev_latent)[0] + + if self.norm_f0: + timbre_rev_latent = quantized[0] + quantized[1] + quantized[3] + else: + timbre_rev_latent = quantized[1] + quantized[3] + if self.use_gr_x_timbre: + x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0] + else: + x_spk_pred = None + + + + preds = { + 'f0': f0_pred, + 'uv': uv_pred, + 'content': content_pred, + 'timbre': spk_pred, + } + + rev_preds = { + 'rev_f0': rev_f0_pred, + 'rev_uv': rev_uv_pred, + 'rev_content': rev_content_pred, + 'x_timbre': x_spk_pred, + } + return preds, rev_preds + def forward_v2(self, quantized, timbre): + assert self.use_gr_content_global_f0 + prosody_latent = quantized[0] + content_latent = quantized[1] + residual_latent = quantized[2] + content_pred = self.phone_predictor(content_latent)[0] + + # spk_pred = self.timbre_predictor(timbre)[0] + f0_pred, uv_pred = self.f0_predictor(prosody_latent) + + prosody_rev_latent = torch.zeros_like(prosody_latent) + if self.use_gr_content_f0: + prosody_rev_latent += content_latent + if self.use_gr_residual_f0: + prosody_rev_latent += residual_latent + rev_f0_pred, rev_uv_pred = self.rev_f0_predictor(prosody_rev_latent) + + content_rev_latent = torch.zeros_like(content_latent) + if self.use_gr_prosody_phone: + content_rev_latent += prosody_latent + if self.use_gr_residual_phone: + content_rev_latent += residual_latent + rev_content_pred = self.rev_content_predictor(content_rev_latent)[0] + + timbre_rev_latent = prosody_latent + content_latent + residual_latent + if self.use_gr_x_timbre: + x_spk_pred = self.rev_timbre_predictor(timbre_rev_latent)[0] + else: + x_spk_pred = None + + global_f0_pred = self.global_f0_predictor(timbre) + if self.use_gr_content_global_f0: + rev_global_f0_pred = self.rev_global_f0_predictor(content_latent + 
prosody_latent + residual_latent)[0] + + preds = { + 'f0': f0_pred, + 'uv': uv_pred, + 'content': content_pred, + 'timbre': None, + 'global_f0': global_f0_pred, + } + + rev_preds = { + 'rev_f0': rev_f0_pred, + 'rev_uv': rev_uv_pred, + 'rev_content': rev_content_pred, + 'x_timbre': x_spk_pred, + 'rev_global_f0': rev_global_f0_pred, + } + return preds, rev_preds diff --git a/modules/redecoder.py b/modules/redecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..36131c88927375be32f528de61da6d1867ca97fd --- /dev/null +++ b/modules/redecoder.py @@ -0,0 +1,63 @@ +import torch +from modules.wavenet import WN +# +class Redecoder(torch.nn.Module): + def __init__(self, args): + super(Redecoder, self).__init__() + self.n_p_codebooks = args.n_p_codebooks # number of prosody codebooks + self.n_c_codebooks = args.n_c_codebooks # number of content codebooks + self.codebook_size = 1024 # codebook size + self.encoder_type = args.encoder_type + if args.encoder_type == "wavenet": + self.embed_dim = args.wavenet_embed_dim + self.encoder = WN(hidden_channels=self.embed_dim, kernel_size=5, dilation_rate=1, n_layers=16, gin_channels=1024 + , p_dropout=0.2, causal=args.decoder_causal) + self.conv_out = torch.nn.Conv1d(self.embed_dim, 1024, 1) + self.prosody_embed = torch.nn.ModuleList( + [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_p_codebooks)]) + self.content_embed = torch.nn.ModuleList( + [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_c_codebooks)]) + elif args.encoder_type == "mamba": + from modules.mamba import Mambo + self.embed_dim = args.mamba_embed_dim + self.encoder = Mambo(d_model=self.embed_dim, n_layer=24, vocab_size=1024, + prob_random_mask_prosody=args.prob_random_mask_prosody, + prob_random_mask_content=args.prob_random_mask_content,) + self.conv_out = torch.nn.Linear(self.embed_dim, 1024) + self.forward = self.forward_v2 + self.prosody_embed = torch.nn.ModuleList( + [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_p_codebooks)]) + self.content_embed = torch.nn.ModuleList( + [torch.nn.Embedding(self.codebook_size, self.embed_dim) for _ in range(self.n_c_codebooks)]) + else: + raise NotImplementedError + + def forward(self, p_code, c_code, timbre_vec, use_p_code=True, use_c_code=True, n_c=2): + B, _, T = p_code.size() + p_embed = torch.zeros(B, T, self.embed_dim).to(p_code.device) + c_embed = torch.zeros(B, T, self.embed_dim).to(c_code.device) + if use_p_code: + for i in range(self.n_p_codebooks): + p_embed += self.prosody_embed[i](p_code[:, i, :]) + if use_c_code: + for i in range(n_c): + c_embed += self.content_embed[i](c_code[:, i, :]) + x = p_embed + c_embed + x = self.encoder(x.transpose(1, 2), x_mask=torch.ones(B, 1, T).to(p_code.device), g=timbre_vec.unsqueeze(2)) + x = self.conv_out(x) + return x + def forward_v2(self, p_code, c_code, timbre_vec, use_p_code=True, use_c_code=True, n_c=2): + x = self.encoder(torch.cat([p_code, c_code], dim=1), timbre_vec) + x = self.conv_out(x).transpose(1, 2) + return x + @torch.no_grad() + def generate(self, prompt_ids, input_ids, prompt_context, timbre, use_p_code=True, use_c_code=True, n_c=2): + from modules.mamba import InferenceParams + assert self.encoder_type == "mamba" + inference_params = InferenceParams(max_seqlen=8192, max_batch_size=1) + # run once with prompt to initialize memory first + prompt_out = self.encoder(prompt_ids, prompt_context, timbre, inference_params=inference_params) + for i in range(input_ids.size(-1)): + 
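+ # Incremental decoding sketch: tokens are fed one step at a time; inference_params is
+ # assumed to cache the recurrent mamba state, so each call only processes the new token.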
input_id = input_ids[..., i] + prompt_out = self.encoder(input_id, prompt_out, timbre, inference_params=inference_params) + diff --git a/modules/style_encoder.py b/modules/style_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4290c020861c313897bc4745d1607410ee3a10ed --- /dev/null +++ b/modules/style_encoder.py @@ -0,0 +1,91 @@ +from . import attentions +from torch import nn +import torch +from torch.nn import functional as F + +class Mish(nn.Module): + def __init__(self): + super(Mish, self).__init__() + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class Conv1dGLU(nn.Module): + ''' + Conv1d + GLU(Gated Linear Unit) with residual connection. + For GLU refer to https://arxiv.org/abs/1612.08083 paper. + ''' + + def __init__(self, in_channels, out_channels, kernel_size, dropout): + super(Conv1dGLU, self).__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=2) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + residual = x + x = self.conv1(x) + x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1) + x = x1 * torch.sigmoid(x2) + x = residual + self.dropout(x) + return x + +class StyleEncoder(torch.nn.Module): + def __init__(self, in_dim=513, hidden_dim=128, out_dim=256): + + super().__init__() + + self.in_dim = in_dim # Linear 513 wav2vec 2.0 1024 + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.kernel_size = 5 + self.n_head = 2 + self.dropout = 0.1 + + self.spectral = nn.Sequential( + nn.Conv1d(self.in_dim, self.hidden_dim, 1), + Mish(), + nn.Dropout(self.dropout), + nn.Conv1d(self.hidden_dim, self.hidden_dim, 1), + Mish(), + nn.Dropout(self.dropout) + ) + + self.temporal = nn.Sequential( + Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), + ) + + self.slf_attn = attentions.MultiHeadAttention(self.hidden_dim, self.hidden_dim, self.n_head, p_dropout = self.dropout, proximal_bias= False, proximal_init=True) + self.atten_drop = nn.Dropout(self.dropout) + self.fc = nn.Conv1d(self.hidden_dim, self.out_dim, 1) + + def forward(self, x, mask=None): + + # spectral + x = self.spectral(x)*mask + # temporal + x = self.temporal(x)*mask + + # self-attention + attn_mask = mask.unsqueeze(2) * mask.unsqueeze(-1) + y = self.slf_attn(x,x, attn_mask=attn_mask) + x = x+ self.atten_drop(y) + + # fc + x = self.fc(x) + + # temoral average pooling + w = self.temporal_avg_pool(x, mask=mask) + + return w + + def temporal_avg_pool(self, x, mask=None): + if mask is None: + out = torch.mean(x, dim=2) + else: + len_ = mask.sum(dim=2) + x = x.sum(dim=2) + + out = torch.div(x, len_) + return out \ No newline at end of file diff --git a/modules/wavenet.py b/modules/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..d469ffb6379c7cc726ca8ea440e9f61a7a0ddb03 --- /dev/null +++ b/modules/wavenet.py @@ -0,0 +1,174 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from dac.model.encodec import SConv1d + +from . 
import commons +LRELU_SLOPE = 0.1 + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, + groups=channels, dilation=dilation, padding=padding + )) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False): + super(WN, self).__init__() + conv1d_type = SConv1d + assert (kernel_size % 2 == 1) + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm') + + for i in 
range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, + padding=padding, norm='weight_norm', causal=causal) + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal) + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, :self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels:, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) \ No newline at end of file diff --git a/quantize/__init__.py b/quantize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..918ae65035a933b7717395bb24296fe10210785f --- /dev/null +++ b/quantize/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .fvq import * +from .rvq import * diff --git a/quantize/__pycache__/__init__.cpython-310.pyc b/quantize/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d01c5f83c8e18324734a229001e301fc2b957ef Binary files /dev/null and b/quantize/__pycache__/__init__.cpython-310.pyc differ diff --git a/quantize/__pycache__/fvq.cpython-310.pyc b/quantize/__pycache__/fvq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1a4beb91ca5cd663cdb0d5511d22086ae71da48 Binary files /dev/null and b/quantize/__pycache__/fvq.cpython-310.pyc differ diff --git a/quantize/__pycache__/rvq.cpython-310.pyc b/quantize/__pycache__/rvq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..132d050750b30b5a758cf804fb0651db6fe0c5fd Binary files /dev/null and b/quantize/__pycache__/rvq.cpython-310.pyc differ diff --git a/quantize/fvq.py b/quantize/fvq.py new file mode 100644 index 0000000000000000000000000000000000000000..394f1b9e29677ffe208b1c5ceef989597eedc78d --- /dev/null +++ b/quantize/fvq.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
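+# FactorizedVectorQuantize (below) projects the input into a lower-dimensional codebook space,
+# L2-normalizes both the encodings and the codebook entries before the nearest-neighbour lookup,
+# and uses a straight-through estimator so gradients flow through quantization during training.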
+ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +class FactorizedVectorQuantize(nn.Module): + def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + + if dim != self.codebook_dim: + self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim)) + self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim)) + else: + self.in_proj = nn.Identity() + self.out_proj = nn.Identity() + self._codebook = nn.Embedding(codebook_size, self.codebook_dim) + + @property + def codebook(self): + return self._codebook + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + # transpose since we use linear + + z = rearrange(z, "b d t -> b t d") + + # Factorized codes project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x T x D) + z_e = rearrange(z_e, "b t d -> b d t") + z_q, indices = self.decode_latents(z_e) + + if self.training: + commitment_loss = ( + F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + * self.commitment + ) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + commit_loss = commitment_loss + codebook_loss + else: + commit_loss = torch.zeros(z.shape[0], device=z.device) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = rearrange(z_q, "b d t -> b t d") + z_q = self.out_proj(z_q) + z_q = rearrange(z_q, "b t d -> b d t") + + return z_q, indices, commit_loss + + def vq2emb(self, vq, proj=True): + emb = self.embed_code(vq) + if proj: + emb = self.out_proj(emb) + return emb.transpose(1, 2) + + def get_emb(self): + return self.codebook.weight + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + # L2 normalize encodings and codebook + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices diff --git a/quantize/rvq.py b/quantize/rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..55d2e087a37a84ad367ac4caf937857783dac3c0 --- /dev/null +++ b/quantize/rvq.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch +from torch import nn +from .fvq import FactorizedVectorQuantize + + +class ResidualVQ(nn.Module): + """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf""" + + def __init__(self, *, num_quantizers, codebook_size, **kwargs): + super().__init__() + VQ = FactorizedVectorQuantize + if type(codebook_size) == int: + codebook_size = [codebook_size] * num_quantizers + self.layers = nn.ModuleList( + [VQ(codebook_size=2**size, **kwargs) for size in codebook_size] + ) + self.num_quantizers = num_quantizers + self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0) + self.dropout_type = kwargs.get("dropout_type", None) + + def forward(self, x, n_quantizers=None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + all_quantized = [] + + if n_quantizers is None: + n_quantizers = self.num_quantizers + if self.training: + n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1 + if self.dropout_type == "linear": + dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],)) + elif self.dropout_type == "exp": + dropout = torch.randint( + 1, int(math.log2(self.num_quantizers)), (x.shape[0],) + ) + dropout = torch.pow(2, dropout) + n_dropout = int(x.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(x.device) + + for idx, layer in enumerate(self.layers): + if not self.training and idx >= n_quantizers: + break + quantized, indices, loss = layer(residual) + + mask = ( + torch.full((x.shape[0],), fill_value=idx, device=x.device) + < n_quantizers + ) + + residual = residual - quantized + + quantized_out = quantized_out + quantized * mask[:, None, None] + + # loss + loss = (loss * mask).mean() + + all_indices.append(indices) + all_losses.append(loss) + all_quantized.append(quantized) + all_losses, all_indices, all_quantized = map( + torch.stack, (all_losses, all_indices, all_quantized) + ) + return quantized_out, all_indices, all_losses, all_quantized + + def vq2emb(self, vq): + # vq: [n_quantizers, B, T] + quantized_out = 0.0 + for idx, layer in enumerate(self.layers): + quantized = layer.vq2emb(vq[idx]) + quantized_out += quantized + return quantized_out + + def get_emb(self): + embs = [] + for idx, layer in enumerate(self.layers): + embs.append(layer.get_emb()) + return embs diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..62076f8ba771f149cfdd3348b9bc18728f46e13c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +torch +numpy +einops +tqdm +argbind +torchaudio +soundfile +click +PyYAML +accelerate +torchmetrics +munch +librosa +transformers \ No newline at end of file diff --git a/transformer_modules/__init__.py b/transformer_modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/transformer_modules/__pycache__/__init__.cpython-310.pyc b/transformer_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2af7355849112f7ea9b435994c711f0edd5b5970 Binary files /dev/null and b/transformer_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/transformer_modules/__pycache__/optim.cpython-310.pyc b/transformer_modules/__pycache__/optim.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ee40e295fb6301fd9cdb8ee46aef43e818a90b3c Binary files /dev/null and b/transformer_modules/__pycache__/optim.cpython-310.pyc differ diff --git a/transformer_modules/activation.py b/transformer_modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..234b586905d31772d27c5e105ca1f5e749bdea8c --- /dev/null +++ b/transformer_modules/activation.py @@ -0,0 +1,599 @@ +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter +import math + +from .rotary_embedding import RotaryEmbedding + +rotary_emb = None + +def multi_head_attention_forward( + x, + ipw, + ipb, + opw, + opb, + n_head, + attn_mask, + dropout=0.0, + past_kv=None, + use_cache=False, + use_rope=False, + rope=None, +): + rotary_emb = rope + B, T, C = x.size() + + q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1) + k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) + + # implement RoPE here + if use_rope: + if rotary_emb is None: + rotary_emb = RotaryEmbedding(dim = C // n_head) + rotary_emb.to(x.device) + if past_kv is None: + try: + q = rotary_emb.rotate_queries_or_keys(q) + k = rotary_emb.rotate_queries_or_keys(k) + except: + print("?") + else: + q = rotary_emb.rotate_queries_or_keys(q, offset=past_kv[0].shape[-2]) + k = rotary_emb.rotate_queries_or_keys(k, offset=past_kv[0].shape[-2]) + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = [k, v] + else: + present = None + + if T == 1 or attn_mask is None: + with torch.backends.cuda.sdp_kernel(): + y = F.scaled_dot_product_attention(q, k, v) + else: + with torch.backends.cuda.sdp_kernel(): + if attn_mask.dtype == torch.bool: + y = F.scaled_dot_product_attention(q, k, v, attn_mask=~attn_mask[:, :, FULL_T - T:FULL_T, :FULL_T], dropout_p=dropout) + else: + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask[:, :, FULL_T - T:FULL_T, :FULL_T], dropout_p=dropout) + + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = torch._C._nn.linear(y, opw, opb) + return (y, present) + +def multi_head_cross_attention_forward( + q, + k, + v, + ipw, + ipb, + opw, + opb, + n_head, + attn_mask, + dropout=0.0, + past_kv=None, + use_cache=False, + use_rope=False, + rope=None, +): + rotary_emb = rope + B, qT, C = q.size() + _, kT, _ = k.size() + _, vT, _ = v.size() + + q = torch._C._nn.linear(q, ipw[:C, :], ipb[:C]) + k = torch._C._nn.linear(k, ipw[C:2 * C, :], ipb[C:2 * C]) + v = torch._C._nn.linear(v, ipw[2 * C:, :], ipb[2 * C:]) + q = q.view(B, qT, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) + k = k.view(B, kT, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, vT, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs) + + # implement RoPE here + if use_rope: + if rotary_emb is None: + rotary_emb = RotaryEmbedding(dim = C // n_head) + rotary_emb.to(q.device) + if past_kv is None: + q = rotary_emb.rotate_queries_or_keys(q) + k = 
rotary_emb.rotate_queries_or_keys(k) + else: + q = rotary_emb.rotate_queries_or_keys(q, offset=past_kv[0].shape[-2]) + k = rotary_emb.rotate_queries_or_keys(k, offset=past_kv[0].shape[-2]) + else: + q, k = q.contiguous(), k.contiguous() + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + q_FULL_T = q.shape[-2] + k_FULL_T = k.shape[-2] + + if use_cache is True: + present = [k, v] + else: + present = None + + if qT == 1 or attn_mask is None: + with torch.backends.cuda.sdp_kernel(): + y = F.scaled_dot_product_attention(q, k, v) + else: + with torch.backends.cuda.sdp_kernel(): + if attn_mask.dtype == torch.bool: + y = F.scaled_dot_product_attention(q, k, v, attn_mask=~attn_mask[q_FULL_T - qT:q_FULL_T, :k_FULL_T], dropout_p=dropout) + else: + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask[q_FULL_T - qT:q_FULL_T, :k_FULL_T], dropout_p=dropout) + + # calculate and return attention weights + attn_map = torch.bmm(q.view(-1, qT, C // n_head), k.view(-1, kT, C // n_head).transpose(1, 2)).view(B, n_head, qT, kT) + attn_map = attn_map / math.sqrt(C // n_head) + if attn_mask.dtype == torch.bool: + attn_map = attn_map.masked_fill(attn_mask[q_FULL_T - qT:q_FULL_T, :k_FULL_T], -1e5) + else: + attn_map += attn_mask[q_FULL_T - qT:q_FULL_T, :k_FULL_T].unsqueeze(1) + attn_map = F.softmax(attn_map.mean(dim=1), dim=-1) + y = y.transpose(1, 2).contiguous().view(B, qT, C) # re-assemble all head outputs side by side + y = torch._C._nn.linear(y, opw, opb) + return (y, present), attn_map + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces as described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``forward()`` will use a special optimized implementation if all of the following + conditions are met: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This + restriction will be loosened in the future.) + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - dropout is 0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - at most one of ``key_padding_mask`` or ``attn_mask`` is passed + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + + If the optimized implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. 
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = ( + self.kdim == embed_dim and self.vdim == embed_dim + ) + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + self.bias_v = Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = 
add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + use_rope: bool = False, + rope = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. 
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif ( + self.in_proj_bias is not None + and query.dtype != self.in_proj_bias.dtype + ): + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif ( + self.in_proj_weight is not None + and query.dtype != self.in_proj_weight.dtype + ): + # this case will fail anyway, but at least they'll get a useful error message. 
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.dropout: + why_not_fast_path = f"dropout was {self.dropout}, required zero" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = ( + "key_padding_mask is not supported with NestedTensor input" + ) + elif self.num_heads % 2 == 1: + why_not_fast_path = "num_heads is odd" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = ( + "some Tensor argument is neither CUDA nor CPU" + ) + elif torch.is_grad_enabled() and any( + [x is not None and x.requires_grad for x in tensor_args] + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + key_padding_mask + if key_padding_mask is not None + else attn_mask, + need_weights, + average_attn_weights, + 1 + if key_padding_mask is not None + else 0 + if attn_mask is not None + else None, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. 
" + + f"The fast path was not hit because {why_not_fast_path}" + ) + x = query + if query.shape[1] == key.shape[1]: + attn_output, _ = multi_head_attention_forward( + x=x, + ipw=self.in_proj_weight, + ipb=self.in_proj_bias, + opw=self.out_proj.weight, + opb=self.out_proj.bias, + n_head=self.num_heads, + attn_mask=attn_mask, + dropout=self.dropout, + past_kv=None, + use_cache=False, + use_rope=use_rope, + rope=rope, + ) + else: + attn_output = multi_head_cross_attention_forward( + q=query, + k=key, + v=value, + ipw=self.in_proj_weight, + ipb=self.in_proj_bias, + opw=self.out_proj.weight, + opb=self.out_proj.bias, + n_head=self.num_heads, + attn_mask=attn_mask, + dropout=self.dropout, + past_kv=None, + use_cache=False, + use_rope=use_rope, + rope=rope, + ) + return attn_output, None + + def infer(self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + past_kv = None, + use_cache = False, + use_rope = False, + rope = None + ): + # x = x.transpose(1, 0) + y, kv = multi_head_attention_forward( + x=x, + ipw=self.in_proj_weight, + ipb=self.in_proj_bias, + opw=self.out_proj.weight, + opb=self.out_proj.bias, + n_head=self.num_heads, + attn_mask=attn_mask, + past_kv=past_kv, + use_cache=use_cache, + use_rope=use_rope, + rope=rope, + ) + return (y, kv) + + def __repr__(self): + s = ( + f"MultiheadAttention(" + f"embed_dim={self.embed_dim}, " + f"num_heads={self.num_heads}, " + f"dropout={self.dropout}, " + f"bias={self.bias_k is not None}, " + f"add_bias_kv={self.bias_k is not None}, " + f"add_zero_attn={self.add_zero_attn}, " + f"kdim={self.kdim}, " + f"vdim={self.vdim}, " + f"batch_first={self.batch_first}" + ) + if self._qkv_same_embed_dim: + s += ", linear1_cls=Linear" + else: + s += ", linear1_cls=NotImplemented" + if self.bias_k is not None: + s += ", bias_k=Parameter containing:\n" + str(self.bias_k) + if self.bias_v is not None: + s += ", bias_v=Parameter containing:\n" + str(self.bias_v) + s += ")" + return s diff --git a/transformer_modules/conv.py b/transformer_modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d565a6f4f23d2227c29cadd8473d93642d773995 --- /dev/null +++ b/transformer_modules/conv.py @@ -0,0 +1,157 @@ + +import torch +import torch.nn as nn + +class MultiLayeredConv1d(torch.nn.Module): + """Multi-layered conv1d for Transformer block. + + This is a module of multi-layered conv1d designed to replace position-wise feed-forward network + in Transformer block, which is introduced in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__( + self, in_chans: int, hidden_chans: int, kernel_size: int, dropout_rate: float + ): + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Conv1d( + hidden_chans, in_chans, 1, stride=1, padding=(1 - 1) // 2 + ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Batch of input tensors (B, *, in_chans). 
+ + Returns: + Tensor: Batch of output tensors (B, *, hidden_chans) + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + +class MultiLayeredConv1d(torch.nn.Module): + """Multi-layered conv1d for Transformer block. + + This is a module of multi-leyered conv1d designed to replace positionwise feed-forward network + in Transforner block, which is introduced in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__( + self, in_chans: int, hidden_chans: int, kernel_size=5, dropout_rate=0.0, + ): + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Conv1d( + hidden_chans, in_chans, 1, stride=1, padding=(1 - 1) // 2 + ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Batch of input tensors (B, *, in_chans). + + Returns: + Tensor: Batch of output tensors (B, *, hidden_chans) + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + +class Swish(torch.nn.Module): + """ + Construct an Swish activation function for Conformer. + """ + + def forward(self, x): + """ + Return Swish activation function. + """ + return x * torch.sigmoid(x) +class ConvolutionModule(nn.Module): + """ + ConvolutionModule in Conformer model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + + """ + + def __init__(self, channels, kernel_size, activation=Swish(), ignore_prefix_len=0, bias=True): + super(ConvolutionModule, self).__init__() + # kernel_size should be an odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) + self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, ) + self.norm = nn.GroupNorm(num_groups=32, num_channels=channels) + self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) + self.activation = activation + self.ignore_prefix_len = ignore_prefix_len + + def forward(self, x): + """ + Compute convolution module. + + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + + Returns: + torch.Tensor: Output tensor (#batch, time, channels). 
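A hedged usage sketch of the convolution modules in this file (the forward pass of `ConvolutionModule` continues just below). Both operate on `(batch, time, channels)` tensors and return the same shape; the sizes and the import path are placeholders.

```python
import torch
# from transformer_modules.conv import MultiLayeredConv1d, ConvolutionModule  # assumed path

x = torch.randn(2, 100, 256)                     # (batch, time, channels)

# Position-wise FFN replacement: conv(k=5) -> ReLU -> dropout -> conv(k=1).
ffn = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=5, dropout_rate=0.1)
print(ffn(x).shape)                              # torch.Size([2, 100, 256])

# Conformer-style convolution block: pointwise conv + GLU, depthwise conv,
# GroupNorm + Swish, pointwise conv. channels must be divisible by the 32 groups
# used by nn.GroupNorm, and kernel_size must be odd.
conv = ConvolutionModule(channels=256, kernel_size=31)
print(conv(x).shape)                             # torch.Size([2, 100, 256])
```

Two things worth noticing in the code above: the `MultiLayeredConv1d` docstring advertises `(B, *, hidden_chans)` outputs, but the second 1x1 conv projects back to `in_chans`, so the shape is preserved; and the file defines `MultiLayeredConv1d` twice, so the second definition (default `kernel_size=5`, `dropout_rate=0.0`) is the one in effect at import time.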
+ + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x_sub = self.depthwise_conv(x[..., self.ignore_prefix_len:]) + x_sub = self.activation(self.norm(x_sub)) + x_pre = x[..., :self.ignore_prefix_len] + # x = self.depthwise_conv(x) + # x = self.activation(self.norm(x)) + x = torch.cat([x_pre, x_sub], dim=-1) + + x = self.pointwise_conv2(x) + + return x.transpose(1, 2) \ No newline at end of file diff --git a/transformer_modules/embedding.py b/transformer_modules/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..15bf6d47882a9bd1c1b90fe5a77f08fbf4b75a37 --- /dev/null +++ b/transformer_modules/embedding.py @@ -0,0 +1,105 @@ +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + + +class TokenEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.dim_model = dim_model + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + X = self.word_embeddings(x) + X = self.dropout(X) + + return X + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.dim_model = dim_model + self.x_scale = math.sqrt(dim_model) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + + self.reverse = False + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, 4000)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.dim_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange( + 0, x.size(1), dtype=torch.float32 + ).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.dim_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.dim_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype).detach() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = 
output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(output) + + def infer(self, x, position_ids): + """ + infer only a single or a few tokens to save time + """ + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * self.pe[:, position_ids] + return self.dropout(output) \ No newline at end of file diff --git a/transformer_modules/optim.py b/transformer_modules/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab330f1eece785cf759be3fe393a36b056018ef --- /dev/null +++ b/transformer_modules/optim.py @@ -0,0 +1,1104 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import random +from collections import defaultdict +from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. 
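Returning briefly to `embedding.py` above, a hedged usage sketch of `TokenEmbedding` and `SinePositionalEmbedding`; the import path and the sizes are assumptions.

```python
import torch
# from transformer_modules.embedding import TokenEmbedding, SinePositionalEmbedding  # assumed

emb = TokenEmbedding(dim_model=256, vocab_size=1024, dropout=0.1)
pos = SinePositionalEmbedding(dim_model=256, dropout=0.1, scale=True, alpha=True)

tokens = torch.randint(0, 1024, (2, 50))   # (batch, seq)
x = pos(emb(tokens))                       # (2, 50, 256): sqrt(d)-scaled input + alpha * sinusoidal PE

# Incremental decoding: infer() adds the positional term only at the requested positions.
step = torch.randn(2, 1, 256)
step = pos.infer(step, position_ids=torch.tensor([50]))
```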
+ """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [ + batches_names[batches_names_keys[idx]] for idx in sorted_idx + ] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [ + torch.zeros_like(p) if p.grad is None else p.grad + for p in batch + ] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. 
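Unlike stock PyTorch optimizers, `ScaledAdam` insists on `parameters_names`: a `List[List[str]]` with one list per parameter group, in the same order as the parameters themselves. A hedged construction sketch, assuming the class is importable from this module:

```python
import torch
# from transformer_modules.optim import ScaledAdam  # assumed import path

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)

# One parameter group -> one list of names, matching the order of model.parameters().
names = [[name for name, _ in model.named_parameters()]]

optim = ScaledAdam(
    model.parameters(),
    lr=0.03,               # deliberately much higher than typical Adam learning rates
    clipping_scale=2.0,    # clip to 2x the median recent (rms-weighted) gradient norm
    parameters_names=names,
)
```

As written, the `_test_scaled_adam` helper near the end of this file constructs `ScaledAdam` without `parameters_names` and would trip the assertion in `__init__`.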
+ eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + parameters_names=None, + show_dominant_parameters=True, + ): + + assert parameters_names is not None, ( + "Please data parameters_names," + "which is a List[List[str]]. Each List[str] is for a group" + "and each str is for a parameter" + ) + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + super(ScaledAdam, self).__init__(params, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + self.show_dominant_parameters = show_dominant_parameters + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip( + self.param_groups, self.parameters_names + ): + + with self.batched_params( + group["params"], group_params_names + ) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + # State initialization + if len(state) == 0: + self._init_state(group, p, state) + + self._step_one_batch(group, p, state, clipping_scale) + + return loss + + def _init_state(self, group: dict, p: Tensor, state: dict): + """ + Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p + is actually the batch dimension, corresponding to batched-together + parameters of a given shape. + + + Args: + group: Dict to look up configuration values. 
+ p: The parameter that we are initializing the state for + state: Dict from string to whatever state we are initializing + """ + size_update_period = group["size_update_period"] + + state["step"] = 0 + + kwargs = {"device": p.device, "dtype": p.dtype} + + # 'delta' implements conventional momentum. There are + # several different kinds of update going on, so rather than + # compute "exp_avg" like in Adam, we store and decay a + # parameter-change "delta", which combines all forms of + # update. this is equivalent to how it's done in Adam, + # except for the first few steps. + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + batch_size = p.shape[0] + numel = p.numel() // batch_size + numel = p.numel() + + if numel > 1: + # "param_rms" just periodically records the scalar root-mean-square value of + # the parameter tensor. + # it has a shape like (batch_size, 1, 1, 1, 1) + param_rms = ( + (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) + state["param_rms"] = param_rms + + state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) + state["scale_grads"] = torch.zeros( + size_update_period, *param_rms.shape, **kwargs + ) + + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += ( + grad ** 2 + ).sum() # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + if step % clipping_update_period == 0: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. 
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + quartiles = [] + for n in range(0, 5): + index = min( + clipping_update_period - 1, + (clipping_update_period // 4) * n, + ) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + threshold = clipping_scale * median + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.info( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + if step < clipping_update_period: + return 1.0 # We have not yet estimated a norm to clip to. + else: + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + logging.info( + "Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?" + ) + return 1.0 + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.1: + logging.warn( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) + return ans + + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter wihch dominanting tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_orig = batch_grad ** 2 + # Dummpy values used by following `zip` statement. 
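The clipping rule implemented above boils down to: track the rms-weighted gradient norm over the last `clipping_update_period` steps, then scale gradients so the current norm never exceeds `clipping_scale` times the median of that window. A simplified standalone sketch of the factor computation (the names here are illustrative, not the optimizer's API):

```python
import torch

def clipping_factor(current_norm: torch.Tensor,
                    recent_norms: torch.Tensor,
                    clipping_scale: float = 2.0) -> float:
    """Simplified version of the scale returned by _get_clipping_scale."""
    threshold = clipping_scale * recent_norms.median()
    return min(1.0, (threshold / (current_norm + 1.0e-20)).item())

recent = torch.tensor([0.9, 1.0, 1.1, 1.2, 5.0])    # one spiky minibatch in the window
print(clipping_factor(torch.tensor(1.0), recent))   # 1.0 -> no clipping
print(clipping_factor(torch.tensor(6.0), recent))   # ~0.37 -> gradients scaled down
```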
+ batch_rms_orig = torch.ones(p.shape[0]) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) + + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), + key=lambda item: item[1][0], + reverse=True, + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter Dominanting tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq = {(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + def _step_one_batch( + self, group: dict, p: Tensor, state: dict, clipping_scale: float + ): + """ + Do the step for one parameter, which is actually going to be a batch of + `real` parameters, with dim 0 as the batch dim. + Args: + group: dict to look up configuration values + p: parameter to update (actually multiple parameters stacked together + as a batch) + state: state-dict for p, to look up the optimizer state + """ + lr = group["lr"] + size_update_period = group["size_update_period"] + beta1 = group["betas"][0] + + grad = p.grad + if clipping_scale != 1.0: + grad = grad * clipping_scale + step = state["step"] + delta = state["delta"] + + delta.mul_(beta1) + batch_size = p.shape[0] + numel = p.numel() // batch_size + if numel > 1: + # Update the size/scale of p, and set param_rms + scale_grads = state["scale_grads"] + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + if step % size_update_period == size_update_period - 1: + param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) + param_rms.copy_( + (p ** 2) + .mean(dim=list(range(1, p.ndim)), keepdim=True) + .sqrt() + ) + if step > 0: + # self._size_update() learns the overall scale on the + # parameter, by shrinking or expanding it. + self._size_update(group, scale_grads, p, state) + + if numel == 1: + # For parameters with 1 element we just use regular Adam. + # Updates delta. + self._step_scalar(group, p, state) + else: + self._step(group, p, state) + + state["step"] = step + 1 + + def _size_update( + self, group: dict, scale_grads: Tensor, p: Tensor, state: dict + ) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. 
+ p: The parameter to update + state: The state-dict of p + """ + + param_rms = state["param_rms"] + beta1, beta2 = group["betas"] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_min_rms = group["param_min_rms"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + step = state["step"] + batch_size = p.shape[0] + + size_update_period = scale_grads.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2 ** size_update_period + + scale_exp_avg_sq = state[ + "scale_exp_avg_sq" + ] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads ** 2).mean( + dim=0 + ), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr ** size_step + # we don't bother with bias_correction1; this will help prevent divergence + # at the start of training. + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr + * (bias_correction2 ** 0.5) + * scale_grads.sum(dim=0) + / denom + ) + + is_too_small = param_rms < param_min_rms + is_too_large = param_rms > param_max_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + # when it gets too large, stop it from getting any larger. + scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) + delta = state["delta"] + # the factor of (1-beta1) relates to momentum. + delta.add_(p * scale_step, alpha=(1 - beta1)) + + def _step(self, group: dict, p: Tensor, state: dict): + """ + This function does the core update of self.step(), in the case where the members of + the batch have more than 1 element. + + Args: + group: A dict which will be used to look up configuration values + p: The parameter to be updated + grad: The grad of p + state: The state-dict corresponding to parameter p + + This function modifies p. + """ + grad = p.grad + lr = group["lr"] + beta1, beta2 = group["betas"] + eps = group["eps"] + param_min_rms = group["param_min_rms"] + step = state["step"] + + exp_avg_sq = state["exp_avg_sq"] + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + + this_step = state["step"] - ( + state["zero_step"] if "zero_step" in state else 0 + ) + bias_correction2 = 1 - beta2 ** (this_step + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + + denom = exp_avg_sq.sqrt() + denom += eps + grad = grad / denom + + alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + + delta = state["delta"] + delta.add_(grad * alpha) + p.add_(delta) + + def _step_scalar(self, group: dict, p: Tensor, state: dict): + """ + A simplified form of the core update for scalar tensors, where we cannot get a good + estimate of the parameter rms. + """ + beta1, beta2 = group["betas"] + scalar_max = group["scalar_max"] + eps = group["eps"] + lr = group["lr"] * group["scalar_lr_scale"] + grad = p.grad + + exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. Don't bother with bias_correction1; + # slower update at the start will help stability anyway. 
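The size update above treats each tensor as `p = underlying_param * scale.exp()` and performs an Adam-like update on the log-scale, using `(p * grad).sum()` (accumulated in `scale_grads`) as the gradient with respect to that log-scale. A small standalone check of why that inner product is the right quantity; all names below are illustrative only.

```python
import torch

# If p = underlying * exp(s), then dL/ds = sum_i dL/dp_i * dp_i/ds = sum_i grad_i * p_i.
underlying = torch.randn(10)
s = torch.tensor(0.3, requires_grad=True)
p = underlying * s.exp()

loss = (p ** 2).sum()        # any loss works; gradients flow through p
loss.backward()

grad_p = 2 * p               # dL/dp for this particular loss
print(torch.allclose(s.grad, (grad_p * p).sum()))   # True
```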
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1) + denom = (exp_avg_sq / bias_correction2).sqrt() + eps + + delta = state["delta"] + delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + p.clamp_(min=-scalar_max, max=scalar_max) + p.add_(delta) + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.info( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + def get_lr(self): + factor = ( + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + ) ** -0.25 * ( + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of transformer_modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
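Looping back to the Eden scheduler above: it multiplies the base learning rate by a batch-dependent factor, an epoch-dependent factor, and a warmup ramp. A standalone check of the post-warmup formula quoted in the docstring (values below are illustrative):

```python
def eden_factor(batch: float, epoch: float,
                lr_batches: float = 5000.0, lr_epochs: float = 6.0) -> float:
    # lr = base_lr * batch_factor * epoch_factor (warmup factor omitted here)
    batch_factor = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
    epoch_factor = ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    return batch_factor * epoch_factor

print(eden_factor(0, 0))          # 1.0: full base_lr at the start of training
print(eden_factor(20000, 20))     # ~0.26: decayed late in training
```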
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) + if not 0 <= weight_decay <= 0.1: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "AdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + # if random.random() < 0.0005: + # step = (exp_avg / denom) * step_size + # logging.info( + # f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + # ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. 
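A hedged usage sketch of the `Eve` optimizer defined above. It behaves like AdamW except that the weight decay is only applied while a tensor's norm exceeds `target_rms * sqrt(numel)`; the import path is an assumption.

```python
import torch
# from transformer_modules.optim import Eve  # assumed import path

model = torch.nn.Linear(32, 32)
optim = Eve(model.parameters(), lr=1e-3, weight_decay=1e-3, target_rms=0.1)

x, y = torch.randn(8, 32), torch.randn(8, 32)
loss = ((model(x) - y) ** 2).mean()
loss.backward()
optim.step()       # decay is skipped for tensors whose rms is already <= target_rms
optim.zero_grad()
```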
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) + * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 2 ** 22 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/transformer_modules/rotary_embedding.py b/transformer_modules/rotary_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..4abad0a897d73efb1a06363956c6be837592a48a --- /dev/null +++ b/transformer_modules/rotary_embedding.py @@ -0,0 +1,220 @@ +from math import pi, log + +import torch +from torch import nn, einsum + +from einops import rearrange, repeat + +# helper functions + +def exists(val): + return val is not None + +def default(val, 
d): + return val if exists(val) else d + +def broadcat(tensors, dim = -1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' + shape_len = list(shape_lens)[0] + + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim = dim) + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + +def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2): + rot_dim, seq_len = freqs.shape[-1], t.shape[seq_dim] + freqs = freqs[-seq_len:].to(t) + + end_index = start_index + rot_dim + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim = -1) + +# learned rotation helpers + +def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r = 2) + return apply_rotary_emb(rotations, t, start_index = start_index) + +# classes + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + custom_freqs = None, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + self.cache = dict() + self.cache_scale = dict() + self.register_buffer('freqs', freqs) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. 
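A hedged usage sketch of this `RotaryEmbedding` class (it relies on `einops`, imported at the top of the file, and on the `rotate_queries_or_keys` method defined just below). Shapes follow the default `seq_before_head_dim=False` convention, i.e. the sequence sits at dim `-2`; the import path is an assumption.

```python
import torch
# from transformer_modules.rotary_embedding import RotaryEmbedding  # assumed path

rope = RotaryEmbedding(dim=32)        # rotates the first 32 features of each head

# (batch, heads, seq, head_dim) with head_dim >= dim
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)

q = rope.rotate_queries_or_keys(q)
k = rope.rotate_queries_or_keys(k)
# Dot products q @ k.transpose(-1, -2) now depend on relative positions through the rotation.
```

With `use_xpos=True`, `rotate_queries_and_keys` must be used instead, since queries and keys then receive reciprocal length-extrapolation scales.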
+ self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.register_buffer('scale', scale) + + def get_seq_pos(self, seq_len, device, dtype, offset = 0): + return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + if exists(freq_seq_len): + assert freq_seq_len >= seq_len + seq_len = freq_seq_len + + freqs = self.forward(lambda: self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), cache_key = f'freqs:{seq_len}|offset:{offset}') + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, seq_dim = seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): + seq_dim = default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len) + k = self.rotate_queries_or_keys(k, seq_dim = seq_dim) + return q, k + + def rotate_queries_and_keys(self, q, k, seq_dim = None, pid = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + if pid is None: + seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) + freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}') + scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}').to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + else: + # if pid (position id) is given, it indicates kv cache is used + # pid equals to len(k_cache) in this case + seq = self.get_seq_pos(1, dtype = dtype, device = device) + pid + freqs = self.forward(lambda: seq, cache_key = None) + scale = self.get_scale(lambda: seq, cache_key = None).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + return rotated_q, rotated_k + + def get_scale(self, t, cache_key = None): + assert self.use_xpos + + if exists(cache_key) and cache_key in self.cache: + return self.cache[cache_key] + + if callable(t): + t = t() + + scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + if exists(cache_key): + self.cache[cache_key] = scale + + return scale + + def forward(self, t, cache_key = None): + if exists(cache_key) and cache_key in self.cache: + return self.cache[cache_key] + + if callable(t): + t = t() + + freqs = self.freqs + + freqs = einsum('..., f -> ... 
f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + + if exists(cache_key): + self.cache[cache_key] = freqs + + return freqs diff --git a/transformer_modules/scaling.py b/transformer_modules/scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..e008236132b0cef6c849740716d61d85f4a6625c --- /dev/null +++ b/transformer_modules/scaling.py @@ -0,0 +1,1399 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import collections +import logging +import random +import math +from functools import reduce +from itertools import repeat +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Embedding as ScaledEmbedding + + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + if sign_factor is None: + ctx.save_for_backward(xgt0, scale_factor) + else: + ctx.save_for_backward(xgt0, scale_factor, sign_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + if len(ctx.saved_tensors) == 3: + xgt0, scale_factor, sign_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + sign_factor = sign_factor.unsqueeze(-1) + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + else: + xgt0, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) + + if min_abs == 0.0: + below_threshold = 0.0 + else: + # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if + # x_abs)_mean , min_abs. 
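        # i.e. below_threshold is 0 when x_abs_mean >= min_abs and grows linearly
        # (capped at max_factor) as x_abs_mean falls below min_abs.  For example,
        # with min_abs=0.2, gain_factor=0.02 and x_abs_mean=0.1, the unclamped
        # value is (0.2 - 0.1) * (0.02 / 0.2) = 0.01.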
+ below_threshold = ( + (min_abs - x_abs_mean) * (gain_factor / min_abs) + ).clamp(min=0, max=max_factor) + + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) + + return below_threshold - above_threshold + + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + if min_positive == 0.0: + factor1 = 0.0 + else: + # 0 if proportion_positive >= min_positive, else can be + # as large as max_factor. + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) + + if max_positive == 1.0: + factor2 = 0.0 + else: + # 0 if self.proportion_positive <= max_positive, else can be + # as large as -max_factor. + factor2 = ( + (proportion_positive - max_positive) + * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) + sign_factor = factor1 - factor2 + # require min_positive != 0 or max_positive != 1: + assert not isinstance(sign_factor, float) + return sign_factor + + +class ActivationScaleBalancerFunction(torch.autograd.Function): + """ + This object is used in class ActivationBalancer when the user specified + min_positive=0, max_positive=1, so there are no constraints on the signs + of the activations and only the absolute value has a constraint. + """ + + @staticmethod + def forward( + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + ctx.save_for_backward(xgt0, sign_factor, scale_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + xgt0, sign_factor, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + sign_factor = sign_factor.unsqueeze(-1) + scale_factor = scale_factor.unsqueeze(-1) + + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +class RandomClampFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: + x_clamped = torch.clamp(x, min=min, max=max) + mask = torch.rand_like(x) < prob + ans = torch.where(mask, x_clamped, x) + if x.requires_grad: + ctx.save_for_backward(ans == x) + ctx.reflect = reflect + if reflect != 0.0: + ans = ans * (1.0 + reflect) - (x * reflect) + return ans + + @staticmethod + def backward( + ctx, ans_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None]: + (is_same,) = ctx.saved_tensors + x_grad = ans_grad * is_same.to(ans_grad.dtype) + reflect = ctx.reflect + if reflect != 0.0: + x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) + return x_grad, None, None, None, None + + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): + return RandomClampFunction.apply(x, min, max, prob, reflect) + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. 
+ """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class RandomGradFunction(torch.autograd.Function): + """ + Does nothing in forward pass; in backward pass, gets rid of very small grads using + randomized approach that preserves expectations (intended to reduce roundoff). + """ + + @staticmethod + def forward(ctx, x: Tensor, min_abs: float) -> Tensor: + ctx.min_abs = min_abs + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: + if ans_grad.dtype == torch.float16: + return ( + random_cast_to_half( + ans_grad.to(torch.float32), min_abs=ctx.min_abs + ), + None, + ) + else: + return ans_grad, None + + +class RandomGrad(torch.nn.Module): + """ + Gets rid of very small gradients using an expectation-preserving method, intended to increase + accuracy of training when using amp (automatic mixed precision) + """ + + def __init__(self, min_abs: float = 5.0e-06): + super(RandomGrad, self).__init__() + self.min_abs = min_abs + + def forward(self, x: Tensor): + if ( + torch.jit.is_scripting() + or not self.training + or torch.jit.is_tracing() + ): + return x + else: + return RandomGradFunction.apply(x, self.min_abs) + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. 
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_min: float + eps_max: float + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + eps_min: float = -3.0, + eps_max: float = 3.0, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.eps_min = eps_min + self.eps_max = eps_max + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + eps = self.eps + if self.training and random.random() < 0.25: + # with probability 0.25, in training mode, clamp eps between the min + # and max; this will encourage it to learn parameters within the + # allowed range by making parameters that are outside the allowed + # range noisy. + + # gradients to allow the parameter to get back into the allowed + # region if it happens to exit it. + eps = eps.clamp(min=self.eps_min, max=self.eps_max) + scales = ( + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + ) ** -0.5 + return x * scales + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. 
+ """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + return ans + + +def ScaledConv1d( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + return ans + + +def TransposeScaledConv1d( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Sequential: + """ + Transpose -> ScaledConv1d + """ + return nn.Sequential( + Transpose(), + ScaledConv1d( + *args, + initial_scale=initial_scale, + kernel_size=kernel_size, + padding=padding, + **kwargs, + ), + ) + + +def ScaledConv1dTranspose( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Sequential: + """ + Transpose -> ScaledConv1d + """ + return nn.Sequential( + ScaledConv1d( + *args, + initial_scale=initial_scale, + kernel_size=kernel_size, + padding=padding, + **kwargs, + ), + Transpose(), + ) + + +def TransposeConv1d( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + Transpose -> Conv1d + """ + return nn.Sequential( + Transpose(), + nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + ) + + +def Conv1dTranspose( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + ScaledConv1d -> Transpose + """ + return nn.Sequential( + nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + Transpose(), + ) + + +class SRLinear(nn.Linear): + """https://arxiv.org/abs/2303.06296 + Stabilizing Transformer Training by Preventing Attention Entropy Collapse + """ + + def __init__(self, in_features, out_features, bias=True, **kwargs): + super().__init__(in_features, out_features, bias=bias, **kwargs) + self.register_buffer( + "u", nn.functional.normalize(torch.randn(in_features), dim=0) + ) + with torch.no_grad(): + sigma = self.get_sigma() + self.register_buffer("spectral_norm", sigma) + self.sigma = nn.Parameter(torch.ones(1)) + + def get_sigma(self): + with torch.no_grad(): + u = self.u + v = self.weight.mv(u) + v = nn.functional.normalize(v, dim=0) + u = self.weight.T.mv(v) + u = nn.functional.normalize(u, dim=0) + self.u.data.copy_(u) + return torch.einsum("c,cd,d->", v, self.weight, u) + + def get_weight(self): + sigma = self.get_sigma() + if self.training: + self.spectral_norm.data.copy_(sigma) + weight = (self.sigma / sigma) * self.weight + return weight + + def forward(self, x): + return nn.functional.linear(x, self.get_weight(), self.bias) + + +class SRConv1d(SRLinear): + def __init__( + self, + in_features, + out_features, + kernel_size, 
+ stride: int = 1, + padding: str = "same", + bias: bool = True, + **kwargs, + ): + in_features = in_features * kernel_size + super().__init__(in_features, out_features, bias=bias, **kwargs) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + def forward(self, x): + in_features = self.in_features // self.kernel_size + weight = self.get_weight().view( + self.out_features, in_features, self.kernel_size + ) + return nn.functional.conv1d( + x, weight, bias=self.bias, stride=self.stride, padding=self.padding + ) + + +def TransposeSRConv1d( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + Transpose -> SRConv1d + """ + return nn.Sequential( + Transpose(), + SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + ) + + +def SRConv1dTranspose( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + SRConv1d -> Transpose + """ + return nn.Sequential( + SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + Transpose(), + ) + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + sign_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_positive and max_positive + are violated. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + min_prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. Early in training we may use + higher probabilities than this; it will decay to this value. 
+ """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, + ): + super(ActivationBalancer, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + self.min_prob = min_prob + self.sign_gain_factor = sign_gain_factor + self.scale_gain_factor = scale_gain_factor + + # count measures how many times the forward() function has been called. + # We occasionally sync this to a tensor called `count`, that exists to + # make sure it is synced to disk when we load and save the model. + self.cpu_count = 0 + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or torch.jit.is_tracing() + ): + return _no_op(x) + + count = self.cpu_count + self.cpu_count += 1 + + if random.random() < 0.01: + # Occasionally sync self.cpu_count with self.count. + # count affects the decay of 'prob'. don't do this on every iter, + # because syncing with the GPU is slow. + self.cpu_count = max(self.cpu_count, self.count.item()) + self.count.fill_(self.cpu_count) + + # the prob of doing some work exponentially decreases from 0.5 till it hits + # a floor at min_prob (==0.1, by default) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + + if random.random() < prob: + sign_gain_factor = 0.5 + if self.min_positive != 0.0 or self.max_positive != 1.0: + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) + else: + sign_factor = None + + scale_factor = _compute_scale_factor( + x.detach(), + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) + return ActivationBalancerFunction.apply( + x, + scale_factor, + sign_factor, + self.channel_dim, + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). 
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar ** 2).sum() / ( + num_groups * channels_per_group + ) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float, + ) -> Tensor: + ctx.save_for_backward(x) + ctx.num_groups = num_groups + ctx.whitening_limit = whitening_limit + ctx.grad_scale = grad_scale + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, ctx.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. 
limit={ctx.whitening_limit}" + ) + + (metric - ctx.whitening_limit).relu().backward() + penalty_grad = x_detached.grad + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert whitening_limit >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + if isinstance(prob, float): + assert 0 < prob <= 1 + self.prob = prob + else: + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob < self.max_prob <= 1 + self.prob = self.max_prob + + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + if ( + not x.requires_grad + or random.random() > self.prob + or self.grad_scale == 0 + ): + return _no_op(x) + else: + if hasattr(self, "min_prob") and random.random() < 0.25: + # occasionally switch between min_prob and max_prob, based on whether + # we are above or below the threshold. + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): + # there would be a change to the grad. + self.prob = self.max_prob + else: + self.prob = self.min_prob + + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor): + ctx.y_shape = y.shape + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + +def with_loss(x, y): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + # returns x but adds y.sum() to the loss function. 
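    # In the backward pass WithLoss returns a gradient of all-ones for y, so the
    # effect is the same as having added y.sum() to the training loss.  Typical use
    # (a sketch): x = with_loss(x, aux_loss), then keep using x downstream; if the
    # returned x is dropped, no extra gradient flows and the penalty is silently lost.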
+ return WithLoss.apply(x, y) + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class MaxEig(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to discourage + that any given direction in activation space accounts for more than + a specified proportion of the covariance (e.g. 0.2). + + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + max_var_per_eig: the maximum proportion of the variance of the + features/channels, after mean subtraction, that can come from + any given eigenvalue. + min_prob: the minimum probability with which we apply this during any invocation + of forward(), assuming last time we applied the constraint it was + not active; supplied for speed. + scale: determines the scale with which we modify the gradients, relative + to the existing / unmodified gradients + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, + ): + super(MaxEig, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.scale = scale + assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels + self.max_var_per_eig = max_var_per_eig + + # we figure out the dominant direction using the power method: starting with + # a random vector, keep multiplying by the covariance and renormalizing. + with torch.no_grad(): + # arbitrary.. would use randn() but want to leave the rest of the model's + # random parameters unchanged for comparison + direction = torch.arange(num_channels).to(torch.float) + direction = direction / direction.norm() + self.register_buffer("max_eig_direction", direction) + + self.min_prob = min_prob + # cur_prob is the current probability we'll use to apply the ActivationBalancer. + # We'll regress this towards prob, each tiem we try to apply it and it is not + # active. + self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + or torch.jit.is_tracing() + ): + return _no_op(x) + + with torch.cuda.amp.autocast(enabled=False): + eps = 1.0e-20 + orig_x = x + x = x.to(torch.float32) + with torch.no_grad(): + x = x.transpose(self.channel_dim, -1).reshape( + -1, self.num_channels + ) + x = x - x.mean(dim=0) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. + variance_proportion = (x_var - x_residual_var) / ( + x_var + 1.0e-20 + ) + + # ensure new direction is nonzero even if x == 0, by including `direction`. 
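                # (_find_direction_coeffs above is one power-iteration / least-squares
                #  step: coeffs are the per-frame projections onto the previous
                #  direction, and the returned direction is the one that best
                #  reconstructs x from those coefficients.)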
+ self._set_direction( + 0.1 * self.max_eig_direction + new_direction + ) + + if random.random() < 0.01 or __name__ == "__main__": + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) + + if variance_proportion >= self.max_var_per_eig: + # The constraint is active. Note, we should quite rarely + # reach here, only near the beginning of training if we are + # starting to diverge, should this constraint be active. + cur_prob = self.cur_prob + self.cur_prob = ( + 1.0 # next time, do the update with probability 1.0. + ) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) + else: + # let self.cur_prob exponentially approach self.min_prob, as + # long as the constraint is inactive. + self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob + return orig_x + + def _set_direction(self, direction: Tensor): + """ + Sets self.max_eig_direction to a normalized version of `direction` + """ + direction = direction.detach() + direction = direction / direction.norm() + direction_sum = direction.sum().item() + if direction_sum - direction_sum == 0: # no inf/nan + self.max_eig_direction[:] = direction + else: + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) + + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ + (num_frames, num_channels) = x.shape + assert num_channels > 1 and num_frames > 1 + assert prev_direction.shape == (num_channels,) + # `coeffs` are the coefficients of `prev_direction` in x. + # actually represent the coeffs up to a constant positive factor. + coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 + cur_direction = (x * coeffs).sum(dim=0) / ( + (coeffs ** 2).sum() + 1.0e-20 + ) + return cur_direction, coeffs + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. 
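    For example, at x = 1: s = sigmoid(0) = 0.5, double_swish(1) = 1 * 0.5 = 0.5,
    and double_swish'(1) = 0.5 * (1 - 0.5) + 0.5 = 0.75, matching the direct
    derivative sigmoid(0) + 1 * sigmoid'(0) = 0.5 + 0.25.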
+ """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.043637 + ceil = 1.2 + d_scaled = (deriv - floor) * ( + 255.0 / (ceil - floor) + ) + torch.rand_like(deriv) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +def BalancedDoubleSwish( + d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25 +) -> nn.Sequential: + """ + ActivationBalancer -> DoubleSwish + """ + balancer = ActivationBalancer( + d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob + ) + return nn.Sequential( + balancer, + DoubleSwish(), + ) + + +def _test_max_eig(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad, atol=1.0e-02) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ( + (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + 
max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + min_prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() + _test_whiten() + _test_max_eig() + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() diff --git a/transformer_modules/scheduler.py b/transformer_modules/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..da6b25dd5b88ab3737cc00baff0d88c2cb288978 --- /dev/null +++ b/transformer_modules/scheduler.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
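# A worked example of the Noam schedule implemented below (a sketch, assuming the
# illustrative values dim_embed=512 and warmup_steps=1000): the two branches of the
# min() cross at step == warmup_steps, where the factor is
# 512**-0.5 * 1000**-0.5 ~= 1.4e-3; before that the rate grows linearly with step,
# and afterwards it decays as step**-0.5.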
+ + +import torch + + +def calc_lr(step, dim_embed, warmup_steps): + return dim_embed ** (-0.5) * min( + step ** (-0.5), step * warmup_steps ** (-1.5) + ) + + +class NoamScheduler(torch.optim.lr_scheduler._LRScheduler): + def __init__( + self, + base_lr: float, + optimizer: torch.optim.Optimizer, + dim_embed: int, + warmup_steps: int, + last_epoch: int = -1, + verbose: bool = False, + ) -> None: + + self.dim_embed = dim_embed + self.base_lr = base_lr + self.warmup_steps = warmup_steps + self.num_param_groups = len(optimizer.param_groups) + + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self) -> float: + lr = self.base_lr * calc_lr( + self._step_count, self.dim_embed, self.warmup_steps + ) + return [lr] * self.num_param_groups + + def set_step(self, step: int): + self._step_count = step + + +def get_scheduler(params, optimizer): + if params.scheduler_name.lower() == "eden": + scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) + elif params.scheduler_name.lower() == "noam": + scheduler = NoamScheduler( + params.base_lr, + optimizer, + params.decoder_dim, + warmup_steps=params.warmup_steps, + ) + # scheduler.set_step(params.start_batch or params.batch_idx_train) + elif params.scheduler_name.lower() == "cosine": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + params.warmup_steps, + optimizer, + eta_min=params.base_lr, + ) + else: + raise NotImplementedError(f"{params.scheduler_name}") + + return scheduler diff --git a/transformer_modules/transformer.py b/transformer_modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..960b9a6d1e9debd34b9601d864e1dd5406954be4 --- /dev/null +++ b/transformer_modules/transformer.py @@ -0,0 +1,869 @@ +import copy +import numbers +from functools import partial +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from .activation import MultiheadAttention +from .scaling import ActivationBalancer, BalancedDoubleSwish +from .scaling import BasicNorm as _BasicNorm +from .rotary_embedding import RotaryEmbedding +from .conv import ConvolutionModule, MultiLayeredConv1d +_shape_t = Union[int, List[int], torch.Size] + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] 
+ eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +class BasicNorm(_BasicNorm): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ): + super(BasicNorm, self).__init__(d_model, eps=eps) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + super(BasicNorm, self).forward(input), + embedding, + ) + + assert embedding is None + return super(BasicNorm, self).forward(input) + + +class BalancedBasicNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ): + super(BalancedBasicNorm, self).__init__() + self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return self.norm((self.balancer(input), embedding)) + + assert embedding is None + return self.norm(self.balancer(input)) + + +class IdentityNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: + 
super(IdentityNorm, self).__init__() + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + return input + + assert embedding is None + return input + +class RMSNorm(nn.Module): + def __init__(self, d, p=-1., eps=1e-8, bias=False): + """ + Root Mean Square Layer Normalization + :param d: model size + :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled) + :param eps: epsilon value, default 1e-8 + :param bias: whether use bias term for RMSNorm, disabled by + default because RMSNorm doesn't enforce re-centering invariance. + """ + super(RMSNorm, self).__init__() + + self.eps = eps + self.d = d + self.p = p + self.bias = bias + + self.scale = nn.Parameter(torch.ones(d)) + self.register_parameter("scale", self.scale) + + if self.bias: + self.offset = nn.Parameter(torch.zeros(d)) + self.register_parameter("offset", self.offset) + + def forward(self, x): + if self.p < 0. or self.p > 1.: + norm_x = x.norm(2, dim=-1, keepdim=True) + d_x = self.d + else: + partial_size = int(self.d * self.p) + partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1) + + norm_x = partial_x.norm(2, dim=-1, keepdim=True) + d_x = partial_size + + rms_x = norm_x * d_x ** (-1. / 2) + x_normed = x / (rms_x + self.eps) + + if self.bias: + return self.scale * x_normed + self.offset + + return self.scale * x_normed + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + use_conv_module: bool = False, + use_depth_wise_conv: bool = False, + conv_ignore_prefix_len: int = 0, + cross_attention: bool = False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + if cross_attention: + self.has_cross_attention = True + self.cross_attn = nn.MultiheadAttention( + d_model, nhead, 0.1, batch_first=True + ) + self.norm3 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + # Implementation of Feedforward model + self.use_depth_wise_conv = use_depth_wise_conv + self.use_conv_module = use_conv_module + if not use_depth_wise_conv: + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + else: + self.dw_ffn = MultiLayeredConv1d( + in_chans=d_model, + hidden_chans=dim_feedforward, + kernel_size=5, + dropout_rate=dropout, + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. 
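        # (A string such as "relu" is mapped to its functional form via
        #  _get_activation_fn; a functools.partial or the BalancedDoubleSwish
        #  factory is instead called with d_model, since those activations need
        #  the channel count.)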
+ if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + elif activation == BalancedDoubleSwish: + activation = BalancedDoubleSwish(d_model) + + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + else: + norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + self.rotary_emb = RotaryEmbedding(dim=d_model // nhead) + + if use_conv_module: + self.conv_module = ConvolutionModule( + d_model, + kernel_size=31, + activation=activation, + ignore_prefix_len=conv_ignore_prefix_len, + ) + self.norm_conv = LayerNorm(d_model) # for the CNN module + if adaptive_layer_norm: + self.norm_conv = AdaptiveLayerNorm(d_model, self.norm_conv) + else: + self.conv_module = None + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + context: Optional[Tensor] = None, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + use_rope: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. 
+ """ + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + else: + x, stage_embedding = src, None + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + use_rope=use_rope, + ) + if self.conv_module is not None: + residual = x + x = self.norm_conv(x, stage_embedding) + x = residual + self.dropout1(self.conv_module(x)) + # if self.has_cross_attention: + # x = x + self.cross_attn( + # self.norm3(x, stage_embedding), + # context, + # context, + # attn_mask=src_mask, + # )[0] + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask, use_rope=use_rope), + stage_embedding, + ) + if self.conv_module is not None: + residual = x + x = residual + self.dropout(self.conv_module(x)) + x = self.norm_conv(x, stage_embedding) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + return (x, stage_embedding) + return x + + def infer( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + past_kv: Optional[Tensor] = None, + use_cache: bool = False, + use_rope: bool = False, + ): + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x_attn_out, kv = self.self_attn.infer( + self.norm1(x, stage_embedding), + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + need_weights=False, + past_kv=past_kv, + use_cache=use_cache, + use_rope=use_rope, + rope=self.rotary_emb + ) + x = x + x_attn_out + x = x + self._ff_block(self.norm2(x, stage_embedding)) + + if is_src_tuple: + return (x, stage_embedding) + return (x, kv) + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + use_rope: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + use_rope=use_rope, + rope=self.rotary_emb + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + if self.use_depth_wise_conv: + x = self.dw_ffn(x) + else: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). 
This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + use_rope: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + return_layer_states: return layers' state (optional). + + Shape: + see the docs in Transformer class. + """ + if return_layer_states: + layer_states = [] # layers' output + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + use_rope=use_rope, + ) + layer_states.append(output[0]) + + if self.norm is not None: + output = self.norm(output) + + return layer_states, output + + output = src + for mod in self.layers: + output = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, use_rope=use_rope + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + def infer( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + past_kv: Optional[Tensor] = None, + use_cache: bool = False, + use_rope: bool = False, + ): + if past_kv is None: + past_length = 0 + past_kv = tuple([None] * self.num_layers) + else: + past_length = past_kv[0][0].size(-2) + new_kv = () if use_cache else None + output = src + for mod, past_layer_kv in zip(self.layers, past_kv): + output, kv = mod.infer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache, use_rope=use_rope + ) + if use_cache: + new_kv = new_kv + (kv,) + + if self.norm is not None: + output = self.norm(output) + + return output, new_kv + +class TransformerDecoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + 
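            # (editorial note) This second MultiheadAttention is the cross-attention
            # over the encoder memory, used by _mha_block() below; note that it is
            # constructed with the *self-attention* linear projection classes
            # (linear1_self_attention_cls / linear2_self_attention_cls).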
dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + self.activation = activation(d_model) + elif activation == BalancedDoubleSwish: + self.activation = BalancedDoubleSwish(d_model) + else: + self.activation = activation + + if adaptive_layer_norm: + norm1 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + norm3 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + self.norm3 = AdaptiveLayerNorm(d_model, norm3) + else: + self.norm1 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + self.norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + if layer_norm_cls == IdentityNorm: + self.norm3 = BalancedBasicNorm( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + else: + self.norm3 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + self.rotary_emb = RotaryEmbedding(dim=d_model // nhead) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + use_rope: bool = False, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: the sequence to the decoder layer (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + + Shape: + see the docs in Transformer class. 
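        Example (an illustrative sketch only; ``norm_first=True`` is shown because
        that branch returns the cross-attention map alongside the output, and the
        shapes mirror the ``TransformerDecoder`` docstring example further below)::

            >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8, norm_first=True)
            >>> tgt = torch.rand(10, 32, 512)
            >>> memory = torch.rand(20, 32, 512)
            >>> out, attn_map = decoder_layer(tgt, memory)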
+ """ + tgt_is_tuple = False + if isinstance(tgt, tuple): + x, stage_embedding = tgt + tgt_is_tuple = True + else: + x, stage_embedding = tgt, None + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask, use_rope=use_rope, + ) + x_mha_out, attn_map = self._mha_block( + self.norm2(x, stage_embedding), + memory, + memory_mask, + memory_key_padding_mask, + use_rope=use_rope, + ) + x = x + x_mha_out + x = x + self._ff_block(self.norm3(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, tgt_mask, tgt_key_padding_mask), + stage_embedding, + ) + x = self.norm2( + x + + self._mha_block( + x, memory, memory_mask, memory_key_padding_mask + ), + stage_embedding, + ) + x = self.norm3(x + self._ff_block(x), stage_embedding) + + if tgt_is_tuple: + return (x, stage_embedding) + return x, attn_map + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + use_rope: bool = False, + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + use_rope=use_rope, + rope=self.rotary_emb + )[0] + return self.dropout1(x) + + # multihead attention block + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + use_rope: bool = False, + ) -> Tensor: + x = self.multihead_attn( + x, + mem, + mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + use_rope=use_rope, + rope=self.rotary_emb + )[0] + x, attn_map = x + return self.dropout2(x[0]), attn_map + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + +class TransformerDecoder(nn.Module): + r"""TransformerDecoder is a stack of N decoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + decoder_layer: an instance of the TransformerDecoderLayer() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerDecoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6) + >>> tgt = torch.rand(10, 32, 512) + >>> memory = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ["norm"] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + return_attn: bool = False, + use_rope: bool = False, + ) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layers in turn. + + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). 
+ tgt_mask: the mask for the tgt sequence (optional). + memory_mask: the mask for the memory sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). + return_attn: return cross attention maps of each layer (optional). + + Shape: + see the docs in Transformer class. + """ + attn_maps = [] + output = tgt + for mod in self.layers: + output, attn_map = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + use_rope=use_rope, + ) + if return_attn: + attn_maps.append(attn_map) + + if self.norm is not None: + output = self.norm(output) + + return output, attn_maps + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) diff --git a/webui.py b/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..79b6c4a16c350c57c3d989493c51e09590d5a7ec --- /dev/null +++ b/webui.py @@ -0,0 +1,117 @@ +import gradio as gr +import torch +import torchaudio +import librosa +import numpy as np +import os +from huggingface_hub import hf_hub_download +import yaml +from modules.commons import recursive_munch, build_model + +# setup device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# load model +def load_model(repo_id): + ckpt_path = hf_hub_download(repo_id, "pytorch_model.bin", cache_dir="./checkpoints") + config_path = hf_hub_download(repo_id, "config.yml", cache_dir="./checkpoints") + + config = yaml.safe_load(open(config_path)) + model_params = recursive_munch(config['model_params']) + + if "redecoder" in repo_id: + model = build_model(model_params, stage="redecoder") + else: + model = build_model(model_params, stage="codec") + + ckpt_params = torch.load(ckpt_path, map_location="cpu") + + for key in model: + model[key].load_state_dict(ckpt_params[key]) + model[key].eval() + model[key].to(device) + + return model + + +# load models +codec_model = load_model("Plachta/FAcodec") +redecoder_model = load_model("Plachta/FAcodec-redecoder") + + +# preprocess audio +def preprocess_audio(audio_path, sr=24000): + audio = librosa.load(audio_path, sr=sr)[0] + # if audio has two channels, take the first one + if len(audio.shape) > 1: + audio = audio[0] + audio = audio[:sr * 30] # crop only the first 30 seconds + return torch.tensor(audio).unsqueeze(0).float().to(device) + + +# audio reconstruction function +@torch.no_grad() +def reconstruct_audio(audio): + source_audio = preprocess_audio(audio) + + z = codec_model.encoder(source_audio[None, ...]) + z, _, _, _, _ = codec_model.quantizer(z, source_audio[None, ...], n_c=2) + + reconstructed_wave = codec_model.decoder(z) + + return (24000, reconstructed_wave[0, 0].cpu().numpy()) + + +# voice conversion function +@torch.no_grad() +def voice_conversion(source_audio, target_audio): + source_audio = preprocess_audio(source_audio) + target_audio = preprocess_audio(target_audio) + + z = codec_model.encoder(source_audio[None, ...]) + z, _, _, _, timbre, codes = codec_model.quantizer(z, source_audio[None, ...], n_c=2, return_codes=True) + + z_target = codec_model.encoder(target_audio[None, ...]) + _, _, _, _, timbre_target, 
_ = codec_model.quantizer(z_target, target_audio[None, ...], n_c=2, return_codes=True)
+
+    z_converted = redecoder_model.encoder(codes[0], codes[1], timbre_target, use_p_code=False, n_c=1)
+    converted_wave = redecoder_model.decoder(z_converted)
+
+    return (24000, converted_wave[0, 0].cpu().numpy())
+
+
+# gradio interface
+def gradio_interface():
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            "# FAcodec reconstruction and voice conversion\n"
+            "[![GitHub stars](https://img.shields.io/github/stars/Plachtaa/FAcodec.svg?style=social&label=Star&maxAge=2592000)](https://github.com/Plachtaa/FAcodec)\n\n"
+            "FAcodec from [Natural Speech 3](https://arxiv.org/pdf/2403.03100). The checkpoint used in this demo is trained with an improved pipeline "
+            "in which no annotations are required, so the training data can be scaled up.<br>This model is "
+            "trained on 50k hours of data with over 1 million speakers, with greatly improved timbre diversity compared to "
+            "the [original FAcodec](https://huggingface.co/spaces/amphion/naturalspeech3_facodec)."
+            "<br><br>This project is supported by [Amphion](https://github.com/open-mmlab/Amphion)."
+        )
+
+        with gr.Tab("reconstruction"):
+            with gr.Row():
+                input_audio = gr.Audio(type="filepath", label="Input audio")
+                output_audio = gr.Audio(label="Reconstructed audio")
+                reconstruct_btn = gr.Button("Reconstruct")
+                reconstruct_btn.click(reconstruct_audio, inputs=[input_audio], outputs=[output_audio])
+
+        with gr.Tab("voice conversion"):
+            with gr.Row():
+                source_audio = gr.Audio(type="filepath", label="Source audio")
+                target_audio = gr.Audio(type="filepath", label="Reference audio")
+                converted_audio = gr.Audio(label="Converted audio")
+                convert_btn = gr.Button("Convert")
+                convert_btn.click(voice_conversion, inputs=[source_audio, target_audio], outputs=[converted_audio])
+
+    return demo
+
+
+if __name__ == "__main__":
+    iface = gradio_interface()
+    iface.launch()
\ No newline at end of file
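A minimal sketch of using the helpers above without the Gradio UI. The .wav paths are
placeholders (not files shipped with the repo), and it assumes webui.py has been
imported so that both checkpoints were downloaded and loaded at import time, as above.

# Illustrative programmatic use of webui.py's inference helpers.
import torch
import torchaudio

sr, recon = reconstruct_audio("input.wav")                       # returns (24000, np.ndarray)
torchaudio.save("reconstructed.wav", torch.from_numpy(recon).unsqueeze(0), sr)

sr, converted = voice_conversion("source.wav", "reference.wav")
torchaudio.save("converted.wav", torch.from_numpy(converted).unsqueeze(0), sr)

# The demo itself can also be launched with a public share link:
# gradio_interface().launch(share=True)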