diff --git a/modules/Enhancer/ResembleEnhance.py b/modules/Enhancer/ResembleEnhance.py index 1bf974b2ab1ba5b399d545b35b8f2ef5d3e23e6e..adbf23d68cdf85479d69c7c50ed846badd803f94 100644 --- a/modules/Enhancer/ResembleEnhance.py +++ b/modules/Enhancer/ResembleEnhance.py @@ -1,13 +1,8 @@ import os from typing import List - -try: - from resemble_enhance.enhancer.enhancer import Enhancer - from resemble_enhance.enhancer.hparams import HParams - from resemble_enhance.inference import inference -except: - HParams = dict - Enhancer = dict +from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer +from modules.repos_static.resemble_enhance.enhancer.hparams import HParams +from modules.repos_static.resemble_enhance.inference import inference import torch diff --git a/modules/repos_static/__init__.py b/modules/repos_static/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/repos_static/readme.md b/modules/repos_static/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..9b64e0e376a50b3cf2c05071cca74b8807306811 --- /dev/null +++ b/modules/repos_static/readme.md @@ -0,0 +1,5 @@ +# repos static + +## resemble_enhance + +https://github.com/resemble-ai/resemble-enhance/tree/main diff --git a/modules/repos_static/resemble_enhance/__init__.py b/modules/repos_static/resemble_enhance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/repos_static/resemble_enhance/common.py b/modules/repos_static/resemble_enhance/common.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe3980103294a7b57fce918ffa8592f7b935c50 --- /dev/null +++ b/modules/repos_static/resemble_enhance/common.py @@ -0,0 +1,55 @@ +import logging + +import torch +from torch import Tensor, nn + +logger = logging.getLogger(__name__) + + +class Normalizer(nn.Module): + def __init__(self, momentum=0.01, eps=1e-9): + super().__init__() + self.momentum = momentum + self.eps = eps + self.running_mean_unsafe: Tensor + self.running_var_unsafe: Tensor + self.register_buffer("running_mean_unsafe", torch.full([], torch.nan)) + self.register_buffer("running_var_unsafe", torch.full([], torch.nan)) + + @property + def started(self): + return not torch.isnan(self.running_mean_unsafe) + + @property + def running_mean(self): + if not self.started: + return torch.zeros_like(self.running_mean_unsafe) + return self.running_mean_unsafe + + @property + def running_std(self): + if not self.started: + return torch.ones_like(self.running_var_unsafe) + return (self.running_var_unsafe + self.eps).sqrt() + + @torch.no_grad() + def _ema(self, a: Tensor, x: Tensor): + return (1 - self.momentum) * a + self.momentum * x + + def update_(self, x): + if not self.started: + self.running_mean_unsafe = x.mean() + self.running_var_unsafe = x.var() + else: + self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean()) + self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean()) + + def forward(self, x: Tensor, update=True): + if self.training and update: + self.update_(x) + self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item()) + x = (x - self.running_mean) / self.running_std + return x + + def inverse(self, x: Tensor): + return x * self.running_std + self.running_mean diff --git a/modules/repos_static/resemble_enhance/data/__init__.py b/modules/repos_static/resemble_enhance/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ebc6373ce4e90804e2f12828b7d9467a85656e5 --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/__init__.py @@ -0,0 +1,48 @@ +import logging +import random + +from torch.utils.data import DataLoader + +from ..hparams import HParams +from .dataset import Dataset +from .utils import mix_fg_bg, rglob_audio_files + +logger = logging.getLogger(__name__) + + +def _create_datasets(hp: HParams, mode, val_size=10, seed=123): + paths = rglob_audio_files(hp.fg_dir) + logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}") + + random.Random(seed).shuffle(paths) + train_paths = paths[:-val_size] + val_paths = paths[-val_size:] + + train_ds = Dataset(train_paths, hp, training=True, mode=mode) + val_ds = Dataset(val_paths, hp, training=False, mode=mode) + + logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples") + + return train_ds, val_ds + + +def create_dataloaders(hp: HParams, mode): + train_ds, val_ds = _create_datasets(hp=hp, mode=mode) + + train_dl = DataLoader( + train_ds, + batch_size=hp.batch_size_per_gpu, + shuffle=True, + num_workers=hp.nj, + drop_last=True, + collate_fn=train_ds.collate_fn, + ) + val_dl = DataLoader( + val_ds, + batch_size=1, + shuffle=False, + num_workers=hp.nj, + drop_last=False, + collate_fn=val_ds.collate_fn, + ) + return train_dl, val_dl diff --git a/modules/repos_static/resemble_enhance/data/dataset.py b/modules/repos_static/resemble_enhance/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba57c1736367345d171c2fc4feceefbfc25362a --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/dataset.py @@ -0,0 +1,171 @@ +import logging +import random +from pathlib import Path + +import numpy as np +import torch +import torchaudio +import torchaudio.functional as AF +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset as DatasetBase + +from ..hparams import HParams +from .distorter import Distorter +from .utils import rglob_audio_files + +logger = logging.getLogger(__name__) + + +def _normalize(x): + return x / (np.abs(x).max() + 1e-7) + + +def _collate(batch, key, tensor=True, pad=True): + l = [d[key] for d in batch] + if l[0] is None: + return None + if tensor: + l = [torch.from_numpy(x) for x in l] + if pad: + assert tensor, "Can't pad non-tensor" + l = pad_sequence(l, batch_first=True) + return l + + +def praat_augment(wav, sr): + try: + import parselmouth + except ImportError: + raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation") + # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540", + # https://github.com/YannickJadoul/Parselmouth/issues/68 + # note that this function may hang if the praat version is 0.4.3 + assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}" + sound = parselmouth.Sound(wav, sr) + formant_shift_ratio = random.uniform(1.1, 1.5) + pitch_range_factor = random.uniform(0.5, 2.0) + sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0) + wav = np.array(sound.values)[0].astype(np.float32) + return wav + + +class Dataset(DatasetBase): + def __init__( + self, + fg_paths: list[Path], + hp: HParams, + training=True, + max_retries=100, + silent_fg_prob=0.01, + mode=False, + ): + super().__init__() + + assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}" + + self.hp = hp + self.fg_paths = fg_paths + self.bg_paths = rglob_audio_files(hp.bg_dir) + + if len(self.fg_paths) == 0: + raise ValueError(f"No foreground audio files found in {hp.fg_dir}") + + if len(self.bg_paths) == 0: + raise ValueError(f"No background audio files found in {hp.bg_dir}") + + logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files") + + self.training = training + self.max_retries = max_retries + self.silent_fg_prob = silent_fg_prob + + self.mode = mode + self.distorter = Distorter(hp, training=training, mode=mode) + + def _load_wav(self, path, length=None, random_crop=True): + wav, sr = torchaudio.load(path) + + wav = AF.resample( + waveform=wav, + orig_freq=sr, + new_freq=self.hp.wav_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method="sinc_interp_kaiser", + beta=14.769656459379492, + ) + + wav = wav.float().numpy() + + if wav.ndim == 2: + wav = np.mean(wav, axis=0) + + if length is None and self.training: + length = int(self.hp.training_seconds * self.hp.wav_rate) + + if length is not None: + if random_crop: + start = random.randint(0, max(0, len(wav) - length)) + wav = wav[start : start + length] + else: + wav = wav[:length] + + if length is not None and len(wav) < length: + wav = np.pad(wav, (0, length - len(wav))) + + wav = _normalize(wav) + + return wav + + def _getitem_unsafe(self, index: int): + fg_path = self.fg_paths[index] + + if self.training and random.random() < self.silent_fg_prob: + fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32) + else: + fg_wav = self._load_wav(fg_path) + if random.random() < self.hp.praat_augment_prob and self.training: + fg_wav = praat_augment(fg_wav, self.hp.wav_rate) + + if self.hp.load_fg_only: + bg_wav = None + fg_dwav = None + bg_dwav = None + else: + fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32) + if self.training: + bg_path = random.choice(self.bg_paths) + else: + # Deterministic for validation + bg_path = self.bg_paths[index % len(self.bg_paths)] + bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training) + bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32) + + return dict( + fg_wav=fg_wav, + bg_wav=bg_wav, + fg_dwav=fg_dwav, + bg_dwav=bg_dwav, + ) + + def __getitem__(self, index: int): + for i in range(self.max_retries): + try: + return self._getitem_unsafe(index) + except Exception as e: + if i == self.max_retries - 1: + raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e + logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping") + index = np.random.randint(0, len(self)) + + def __len__(self): + return len(self.fg_paths) + + @staticmethod + def collate_fn(batch): + return dict( + fg_wavs=_collate(batch, "fg_wav"), + bg_wavs=_collate(batch, "bg_wav"), + fg_dwavs=_collate(batch, "fg_dwav"), + bg_dwavs=_collate(batch, "bg_dwav"), + ) diff --git a/modules/repos_static/resemble_enhance/data/distorter/__init__.py b/modules/repos_static/resemble_enhance/data/distorter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad059fd9af40fbfac1aceebf39fac6a09562c7de --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/distorter/__init__.py @@ -0,0 +1 @@ +from .distorter import Distorter diff --git a/modules/repos_static/resemble_enhance/data/distorter/base.py b/modules/repos_static/resemble_enhance/data/distorter/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d43d84fa840dd25804d9c5e5dc9673f65fdc5b94 --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/distorter/base.py @@ -0,0 +1,104 @@ +import itertools +import os +import random +import time +import warnings + +import numpy as np + +_DEBUG = bool(os.environ.get("DEBUG", False)) + + +class Effect: + def apply(self, wav: np.ndarray, sr: int): + """ + Args: + wav: (T) + sr: sample rate + Returns: + wav: (T) with the same sample rate of `sr` + """ + raise NotImplementedError + + def __call__(self, wav: np.ndarray, sr: int): + """ + Args: + wav: (T) + sr: sample rate + Returns: + wav: (T) with the same sample rate of `sr` + """ + assert len(wav.shape) == 1, wav.shape + + if _DEBUG: + start = time.time() + else: + start = None + + shape = wav.shape + assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}." + wav = self.apply(wav, sr) + assert shape == wav.shape, f"{self}: {shape} != {wav.shape}." + + if start is not None: + end = time.time() + print(f"{self.__class__.__name__}: {end - start:.3f} sec") + + return wav + + +class Chain(Effect): + def __init__(self, *effects): + super().__init__() + + self.effects = effects + + def apply(self, wav, sr): + for effect in self.effects: + wav = effect(wav, sr) + return wav + + +class Maybe(Effect): + def __init__(self, prob, effect): + super().__init__() + + self.prob = prob + self.effect = effect + + if _DEBUG: + warnings.warn("DEBUG mode is on. Maybe -> Must.") + self.prob = 1 + + def apply(self, wav, sr): + if random.random() > self.prob: + return wav + return self.effect(wav, sr) + + +class Choice(Effect): + def __init__(self, *effects, **kwargs): + super().__init__() + self.effects = effects + self.kwargs = kwargs + + def apply(self, wav, sr): + return np.random.choice(self.effects, **self.kwargs)(wav, sr) + + +class Permutation(Effect): + def __init__(self, *effects, n: int | None = None): + super().__init__() + self.effects = effects + self.n = n + + def apply(self, wav, sr): + if self.n is None: + n = np.random.binomial(len(self.effects), 0.5) + else: + n = self.n + if n == 0: + return wav + perms = itertools.permutations(self.effects, n) + effects = random.choice(list(perms)) + return Chain(*effects)(wav, sr) diff --git a/modules/repos_static/resemble_enhance/data/distorter/custom.py b/modules/repos_static/resemble_enhance/data/distorter/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..28428f7789cebb2d174c581111711f4d73f6565b --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/distorter/custom.py @@ -0,0 +1,85 @@ +import logging +import random +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path + +import librosa +import numpy as np +from scipy import signal + +from ..utils import walk_paths +from .base import Effect + +_logger = logging.getLogger(__name__) + + +@dataclass +class RandomRIR(Effect): + rir_dir: Path | None + rir_rate: int = 44_000 + rir_suffix: str = ".npy" + deterministic: bool = False + + @cached_property + def rir_paths(self): + if self.rir_dir is None: + return [] + return list(walk_paths(self.rir_dir, self.rir_suffix)) + + def _sample_rir(self): + if len(self.rir_paths) == 0: + return None + + if self.deterministic: + rir_path = self.rir_paths[0] + else: + rir_path = random.choice(self.rir_paths) + + rir = np.squeeze(np.load(rir_path)) + assert isinstance(rir, np.ndarray) + + return rir + + def apply(self, wav, sr): + # ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158 + + if len(self.rir_paths) == 0: + return wav + + length = len(wav) + + wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast") + rir = self._sample_rir() + + wav = signal.convolve(wav, rir, mode="same") + + actlev = np.max(np.abs(wav)) + if actlev > 0.99: + wav = (wav / actlev) * 0.98 + + wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast") + + if abs(length - len(wav)) > 10: + _logger.warning(f"length mismatch: {length} vs {len(wav)}") + + if length > len(wav): + wav = np.pad(wav, (0, length - len(wav))) + elif length < len(wav): + wav = wav[:length] + + return wav + + +class RandomGaussianNoise(Effect): + def __init__(self, alpha_range=(0.8, 1)): + super().__init__() + self.alpha_range = alpha_range + + def apply(self, wav, sr): + noise = np.random.randn(*wav.shape) + noise_energy = np.sum(noise**2) + wav_energy = np.sum(wav**2) + noise = noise * np.sqrt(wav_energy / noise_energy) + alpha = random.uniform(*self.alpha_range) + return wav * alpha + noise * (1 - alpha) diff --git a/modules/repos_static/resemble_enhance/data/distorter/distorter.py b/modules/repos_static/resemble_enhance/data/distorter/distorter.py new file mode 100644 index 0000000000000000000000000000000000000000..7f787a8cdbf941ae7c8e3ac925d1aa66dad5e978 --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/distorter/distorter.py @@ -0,0 +1,32 @@ +from ...hparams import HParams +from .base import Chain, Choice, Permutation +from .custom import RandomGaussianNoise, RandomRIR + + +class Distorter(Chain): + def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"): + # Lazy import + from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb + + if training: + permutation = Permutation( + RandomRIR(hp.rir_dir), + RandomReverb(), + RandomGaussianNoise(), + RandomOverdrive(), + RandomEqualizer(), + Choice( + RandomLowpassDistorter(), + RandomBandpassDistorter(), + ), + ) + if mode == "denoiser": + super().__init__(permutation) + else: + # 80%: distortion, 20%: clean + super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2])) + else: + super().__init__( + RandomRIR(hp.rir_dir, deterministic=True), + RandomReverb(deterministic=True), + ) diff --git a/modules/repos_static/resemble_enhance/data/distorter/sox.py b/modules/repos_static/resemble_enhance/data/distorter/sox.py new file mode 100644 index 0000000000000000000000000000000000000000..92a2d74033d33b975141c1ece7ac5619d1dfcc39 --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/distorter/sox.py @@ -0,0 +1,176 @@ +import logging +import os +import random +import warnings +from functools import partial + +import numpy as np +import torch + +try: + import augment +except ImportError: + raise ImportError( + "augment is not installed, please install it first using:" + "\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e" + ) + +from .base import Effect + +_logger = logging.getLogger(__name__) +_DEBUG = bool(os.environ.get("DEBUG", False)) + + +class AttachableEffect(Effect): + def attach(self, chain: augment.EffectChain) -> augment.EffectChain: + raise NotImplementedError + + def apply(self, wav: np.ndarray, sr: int): + chain = augment.EffectChain() + chain = self.attach(chain) + tensor = torch.from_numpy(wav)[None].float() # (1, T) + tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr}) + wav = tensor.numpy()[0] # (T,) + return wav + + +class SoxEffect(AttachableEffect): + def __init__(self, effect_name: str, *args, **kwargs): + self.effect_name = effect_name + self.args = args + self.kwargs = kwargs + + def attach(self, chain: augment.EffectChain) -> augment.EffectChain: + _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}") + if not hasattr(chain, self.effect_name): + raise ValueError(f"EffectChain has no attribute {self.effect_name}") + return getattr(chain, self.effect_name)(*self.args, **self.kwargs) + + +class Maybe(AttachableEffect): + """ + Attach an effect with a probability. + """ + + def __init__(self, prob: float, effect: AttachableEffect): + self.prob = prob + self.effect = effect + if _DEBUG: + warnings.warn("DEBUG mode is on. Maybe -> Must.") + self.prob = 1 + + def attach(self, chain: augment.EffectChain) -> augment.EffectChain: + if random.random() > self.prob: + return chain + return self.effect.attach(chain) + + +class Chain(AttachableEffect): + """ + Attach a chain of effects. + """ + + def __init__(self, *effects: AttachableEffect): + self.effects = effects + + def attach(self, chain: augment.EffectChain) -> augment.EffectChain: + for effect in self.effects: + chain = effect.attach(chain) + return chain + + +class Choice(AttachableEffect): + """ + Attach one of the effects randomly. + """ + + def __init__(self, *effects: AttachableEffect): + self.effects = effects + + def attach(self, chain: augment.EffectChain) -> augment.EffectChain: + return random.choice(self.effects).attach(chain) + + +class Generator: + def __call__(self) -> str: + raise NotImplementedError + + +class Uniform(Generator): + def __init__(self, low, high): + self.low = low + self.high = high + + def __call__(self) -> str: + return str(random.uniform(self.low, self.high)) + + +class Randint(Generator): + def __init__(self, low, high): + self.low = low + self.high = high + + def __call__(self) -> str: + return str(random.randint(self.low, self.high)) + + +class Concat(Generator): + def __init__(self, *parts: Generator | str): + self.parts = parts + + def __call__(self): + return "".join([part if isinstance(part, str) else part() for part in self.parts]) + + +class RandomLowpassDistorter(SoxEffect): + def __init__(self, low=2000, high=16000): + super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high))) + + +class RandomBandpassDistorter(SoxEffect): + def __init__(self, low=100, high=1000, min_width=2000, max_width=4000): + super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width)) + + @staticmethod + def _fn(low, high, min_width, max_width): + start = random.randint(low, high) + stop = start + random.randint(min_width, max_width) + return f"{start}-{stop}" + + +class RandomEqualizer(SoxEffect): + def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30): + super().__init__( + "equalizer", + Uniform(low, high), + lambda: f"{random.randint(q_low, q_high)}q", + lambda: random.randint(db_low, db_high), + ) + + +class RandomOverdrive(SoxEffect): + def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80): + super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high)) + + +class RandomReverb(Chain): + def __init__(self, deterministic=False): + super().__init__( + SoxEffect( + "reverb", + Uniform(50, 50) if deterministic else Uniform(0, 100), + Uniform(50, 50) if deterministic else Uniform(0, 100), + Uniform(50, 50) if deterministic else Uniform(0, 100), + ), + SoxEffect("channels", 1), + ) + + +class Flanger(SoxEffect): + def __init__(self): + super().__init__("flanger") + + +class Phaser(SoxEffect): + def __init__(self): + super().__init__("phaser") diff --git a/modules/repos_static/resemble_enhance/data/utils.py b/modules/repos_static/resemble_enhance/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..77f59d345b75cac76c6c423c734ae9c70a626590 --- /dev/null +++ b/modules/repos_static/resemble_enhance/data/utils.py @@ -0,0 +1,43 @@ +from pathlib import Path +from typing import Callable + +from torch import Tensor + + +def walk_paths(root, suffix): + for path in Path(root).iterdir(): + if path.is_dir(): + yield from walk_paths(path, suffix) + elif path.suffix == suffix: + yield path + + +def rglob_audio_files(path: Path): + return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac")) + + +def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7): + """ + Args: + fg: (b, t) + bg: (b, t) + """ + assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}" + fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps) + bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps) + + fg_energy = fg.pow(2).sum(dim=-1, keepdim=True) + bg_energy = bg.pow(2).sum(dim=-1, keepdim=True) + + fg = fg / (fg_energy + eps).sqrt() + bg = bg / (bg_energy + eps).sqrt() + + if callable(alpha): + alpha = alpha() + + assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}" + + mx = alpha * fg + (1 - alpha) * bg + mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps) + + return mx diff --git a/modules/repos_static/resemble_enhance/denoiser/__init__.py b/modules/repos_static/resemble_enhance/denoiser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/repos_static/resemble_enhance/denoiser/__main__.py b/modules/repos_static/resemble_enhance/denoiser/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..86188661c35d10721c94dc21f88f4babf45f6f7d --- /dev/null +++ b/modules/repos_static/resemble_enhance/denoiser/__main__.py @@ -0,0 +1,30 @@ +import argparse +from pathlib import Path + +import torch +import torchaudio + +from .inference import denoise + + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("in_dir", type=Path, help="Path to input audio folder") + parser.add_argument("out_dir", type=Path, help="Output folder") + parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder") + parser.add_argument("--suffix", type=str, default=".wav", help="File suffix") + parser.add_argument("--device", type=str, default="cuda", help="Device") + args = parser.parse_args() + + for path in args.in_dir.glob(f"**/*{args.suffix}"): + print(f"Processing {path} ..") + dwav, sr = torchaudio.load(path) + hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device) + out_path = args.out_dir / path.relative_to(args.in_dir) + out_path.parent.mkdir(parents=True, exist_ok=True) + torchaudio.save(out_path, hwav[None], sr) + + +if __name__ == "__main__": + main() diff --git a/modules/repos_static/resemble_enhance/denoiser/denoiser.py b/modules/repos_static/resemble_enhance/denoiser/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..8b1d49cdc257a84073fd43b205f5f497386ce80f --- /dev/null +++ b/modules/repos_static/resemble_enhance/denoiser/denoiser.py @@ -0,0 +1,181 @@ +import logging + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ..melspec import MelSpectrogram +from .hparams import HParams +from .unet import UNet + +logger = logging.getLogger(__name__) + + +def _normalize(x: Tensor) -> Tensor: + return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7) + + +class Denoiser(nn.Module): + @property + def stft_cfg(self) -> dict: + hop_size = self.hp.hop_size + return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4) + + @property + def n_fft(self): + return self.stft_cfg["n_fft"] + + @property + def eps(self): + return 1e-7 + + def __init__(self, hp: HParams): + super().__init__() + self.hp = hp + self.net = UNet(input_dim=3, output_dim=3) + self.mel_fn = MelSpectrogram(hp) + + self.dummy: Tensor + self.register_buffer("dummy", torch.zeros(1), persistent=False) + + def to_mel(self, x: Tensor, drop_last=True): + """ + Args: + x: (b t), wavs + Returns: + o: (b c t), mels + """ + if drop_last: + return self.mel_fn(x)[..., :-1] # (b d t) + return self.mel_fn(x) + + def _stft(self, x): + """ + Args: + x: (b t) + Returns: + mag: (b f t) in [0, inf) + cos: (b f t) in [-1, 1] + sin: (b f t) in [-1, 1] + """ + dtype = x.dtype + device = x.device + + if x.is_mps: + x = x.cpu() + + window = torch.hann_window(self.stft_cfg["win_length"], device=x.device) + s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True) # (b f t+1) + + s = s[..., :-1] # (b f t) + + mag = s.abs() # (b f t) + + phi = s.angle() # (b f t) + cos = phi.cos() # (b f t) + sin = phi.sin() # (b f t) + + mag = mag.to(dtype=dtype, device=device) + cos = cos.to(dtype=dtype, device=device) + sin = sin.to(dtype=dtype, device=device) + + return mag, cos, sin + + def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor): + """ + Args: + mag: (b f t) in [0, inf) + cos: (b f t) in [-1, 1] + sin: (b f t) in [-1, 1] + Returns: + x: (b t) + """ + device = mag.device + dtype = mag.dtype + + if mag.is_mps: + mag = mag.cpu() + cos = cos.cpu() + sin = sin.cpu() + + real = mag * cos # (b f t) + imag = mag * sin # (b f t) + + s = torch.complex(real, imag) # (b f t) + + if s.isnan().any(): + logger.warning("NaN detected in ISTFT input.") + + s = F.pad(s, (0, 1), "replicate") # (b f t+1) + + window = torch.hann_window(self.stft_cfg["win_length"], device=s.device) + x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False) + + if x.isnan().any(): + logger.warning("NaN detected in ISTFT output, set to zero.") + x = torch.where(x.isnan(), torch.zeros_like(x), x) + + x = x.to(dtype=dtype, device=device) + + return x + + def _magphase(self, real, imag): + mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt() + cos = real / mag + sin = imag / mag + return mag, cos, sin + + def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor): + """ + Args: + mag: (b f t) + cos: (b f t) + sin: (b f t) + Returns: + mag_mask: (b f t) in [0, 1], magnitude mask + cos_res: (b f t) in [-1, 1], phase residual + sin_res: (b f t) in [-1, 1], phase residual + """ + x = torch.stack([mag, cos, sin], dim=1) # (b 3 f t) + mag_mask, real, imag = self.net(x).unbind(1) # (b 3 f t) + mag_mask = mag_mask.sigmoid() # (b f t) + real = real.tanh() # (b f t) + imag = imag.tanh() # (b f t) + _, cos_res, sin_res = self._magphase(real, imag) # (b f t) + return mag_mask, sin_res, cos_res + + def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res): + """Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf""" + sep_mag = F.relu(mag * mag_mask) + sep_cos = cos * cos_res - sin * sin_res + sep_sin = sin * cos_res + cos * sin_res + return sep_mag, sep_cos, sep_sin + + def forward(self, x: Tensor, y: Tensor | None = None): + """ + Args: + x: (b t), a mixed audio + y: (b t), a fg audio + """ + assert x.dim() == 2, f"Expected (b t), got {x.size()}" + x = x.to(self.dummy) + x = _normalize(x) + + if y is not None: + assert y.dim() == 2, f"Expected (b t), got {y.size()}" + y = y.to(self.dummy) + y = _normalize(y) + + mag, cos, sin = self._stft(x) # (b 2f t) + mag_mask, sin_res, cos_res = self._predict(mag, cos, sin) + sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res) + + o = self._istft(sep_mag, sep_cos, sep_sin) + + npad = x.shape[-1] - o.shape[-1] + o = F.pad(o, (0, npad)) + + if y is not None: + self.losses = dict(l1=F.l1_loss(o, y)) + + return o diff --git a/modules/repos_static/resemble_enhance/denoiser/hparams.py b/modules/repos_static/resemble_enhance/denoiser/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..72ec1e5680e1f3323406f1206caf7945e0fb7b3b --- /dev/null +++ b/modules/repos_static/resemble_enhance/denoiser/hparams.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass + +from ..hparams import HParams as HParamsBase + + +@dataclass(frozen=True) +class HParams(HParamsBase): + batch_size_per_gpu: int = 128 + distort_prob: float = 0.5 diff --git a/modules/repos_static/resemble_enhance/denoiser/inference.py b/modules/repos_static/resemble_enhance/denoiser/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..9111321baaa428d46fa2d5f789fc437654c50f8b --- /dev/null +++ b/modules/repos_static/resemble_enhance/denoiser/inference.py @@ -0,0 +1,31 @@ +import logging +from functools import cache + +import torch + +from ..denoiser.denoiser import Denoiser + +from ..inference import inference +from .hparams import HParams + +logger = logging.getLogger(__name__) + + +@cache +def load_denoiser(run_dir, device): + if run_dir is None: + return Denoiser(HParams()) + hp = HParams.load(run_dir) + denoiser = Denoiser(hp) + path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt" + state_dict = torch.load(path, map_location="cpu")["module"] + denoiser.load_state_dict(state_dict) + denoiser.eval() + denoiser.to(device) + return denoiser + + +@torch.inference_mode() +def denoise(dwav, sr, run_dir, device): + denoiser = load_denoiser(run_dir, device) + return inference(model=denoiser, dwav=dwav, sr=sr, device=device) diff --git a/modules/repos_static/resemble_enhance/denoiser/unet.py b/modules/repos_static/resemble_enhance/denoiser/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..2b8f78309ce03f776c4a6d9f28f1f9763c94ea7a --- /dev/null +++ b/modules/repos_static/resemble_enhance/denoiser/unet.py @@ -0,0 +1,144 @@ +import torch.nn.functional as F +from torch import nn + + +class PreactResBlock(nn.Sequential): + def __init__(self, dim): + super().__init__( + nn.GroupNorm(dim // 16, dim), + nn.GELU(), + nn.Conv2d(dim, dim, 3, padding=1), + nn.GroupNorm(dim // 16, dim), + nn.GELU(), + nn.Conv2d(dim, dim, 3, padding=1), + ) + + def forward(self, x): + return x + super().forward(x) + + +class UNetBlock(nn.Module): + def __init__(self, input_dim, output_dim=None, scale_factor=1.0): + super().__init__() + if output_dim is None: + output_dim = input_dim + self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1) + self.res_block1 = PreactResBlock(output_dim) + self.res_block2 = PreactResBlock(output_dim) + self.downsample = self.upsample = nn.Identity() + if scale_factor > 1: + self.upsample = nn.Upsample(scale_factor=scale_factor) + elif scale_factor < 1: + self.downsample = nn.Upsample(scale_factor=scale_factor) + + def forward(self, x, h=None): + """ + Args: + x: (b c h w), last output + h: (b c h w), skip output + Returns: + o: (b c h w), output + s: (b c h w), skip output + """ + x = self.upsample(x) + if h is not None: + assert x.shape == h.shape, f"{x.shape} != {h.shape}" + x = x + h + x = self.pre_conv(x) + x = self.res_block1(x) + x = self.res_block2(x) + return self.downsample(x), x + + +class UNet(nn.Module): + def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.encoder_blocks = nn.ModuleList( + [ + UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5) + for i in range(num_blocks) + ] + ) + self.middle_blocks = nn.ModuleList( + [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)] + ) + self.decoder_blocks = nn.ModuleList( + [ + UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2) + for i in reversed(range(num_blocks)) + ] + ) + self.head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.GELU(), + nn.Conv2d(hidden_dim, output_dim, 1), + ) + + @property + def scale_factor(self): + return 2 ** len(self.encoder_blocks) + + def pad_to_fit(self, x): + """ + Args: + x: (b c h w), input + Returns: + x: (b c h' w'), padded input + """ + hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor + wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor + return F.pad(x, (0, wpad, 0, hpad)) + + def forward(self, x): + """ + Args: + x: (b c h w), input + Returns: + o: (b c h w), output + """ + shape = x.shape + + x = self.pad_to_fit(x) + x = self.input_proj(x) + + s_list = [] + for block in self.encoder_blocks: + x, s = block(x) + s_list.append(s) + + for block in self.middle_blocks: + x, _ = block(x) + + for block, s in zip(self.decoder_blocks, reversed(s_list)): + x, _ = block(x, s) + + x = self.head(x) + x = x[..., : shape[2], : shape[3]] + + return x + + def test(self, shape=(3, 512, 256)): + import ptflops + + macs, params = ptflops.get_model_complexity_info( + self, + shape, + as_strings=True, + print_per_layer_stat=True, + verbose=True, + ) + + print(f"macs: {macs}") + print(f"params: {params}") + + +def main(): + model = UNet(3, 3) + model.test() + + +if __name__ == "__main__": + main() diff --git a/modules/repos_static/resemble_enhance/enhancer/__init__.py b/modules/repos_static/resemble_enhance/enhancer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/repos_static/resemble_enhance/enhancer/__main__.py b/modules/repos_static/resemble_enhance/enhancer/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c1ad5ce68497c73756585009a59ea225c89ab94 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/__main__.py @@ -0,0 +1,129 @@ +import argparse +import random +import time +from pathlib import Path + +import torch +import torchaudio +from tqdm import tqdm + +from .inference import denoise, enhance + + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("in_dir", type=Path, help="Path to input audio folder") + parser.add_argument("out_dir", type=Path, help="Output folder") + parser.add_argument( + "--run_dir", + type=Path, + default=None, + help="Path to the enhancer run folder, if None, use the default model", + ) + parser.add_argument( + "--suffix", + type=str, + default=".wav", + help="Audio file suffix", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device to use for computation, recommended to use CUDA", + ) + parser.add_argument( + "--denoise_only", + action="store_true", + help="Only apply denoising without enhancement", + ) + parser.add_argument( + "--lambd", + type=float, + default=1.0, + help="Denoise strength for enhancement (0.0 to 1.0)", + ) + parser.add_argument( + "--tau", + type=float, + default=0.5, + help="CFM prior temperature (0.0 to 1.0)", + ) + parser.add_argument( + "--solver", + type=str, + default="midpoint", + choices=["midpoint", "rk4", "euler"], + help="Numerical solver to use", + ) + parser.add_argument( + "--nfe", + type=int, + default=64, + help="Number of function evaluations", + ) + parser.add_argument( + "--parallel_mode", + action="store_true", + help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel", + ) + + args = parser.parse_args() + + device = args.device + + if device == "cuda" and not torch.cuda.is_available(): + print("CUDA is not available but --device is set to cuda, using CPU instead") + device = "cpu" + + start_time = time.perf_counter() + + run_dir = args.run_dir + + paths = sorted(args.in_dir.glob(f"**/*{args.suffix}")) + + if args.parallel_mode: + random.shuffle(paths) + + if len(paths) == 0: + print(f"No {args.suffix} files found in the following path: {args.in_dir}") + return + + pbar = tqdm(paths) + + for path in pbar: + out_path = args.out_dir / path.relative_to(args.in_dir) + if args.parallel_mode and out_path.exists(): + continue + pbar.set_description(f"Processing {out_path}") + dwav, sr = torchaudio.load(path) + dwav = dwav.mean(0) + if args.denoise_only: + hwav, sr = denoise( + dwav=dwav, + sr=sr, + device=device, + run_dir=args.run_dir, + ) + else: + hwav, sr = enhance( + dwav=dwav, + sr=sr, + device=device, + nfe=args.nfe, + solver=args.solver, + lambd=args.lambd, + tau=args.tau, + run_dir=run_dir, + ) + out_path.parent.mkdir(parents=True, exist_ok=True) + torchaudio.save(out_path, hwav[None], sr) + + # Cool emoji effect saying the job is done + elapsed_time = time.perf_counter() - start_time + print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/modules/repos_static/resemble_enhance/enhancer/download.py b/modules/repos_static/resemble_enhance/enhancer/download.py new file mode 100644 index 0000000000000000000000000000000000000000..614b9a4b4f9a1a10b79f12ca1a25821247ea2a16 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/download.py @@ -0,0 +1,30 @@ +import logging +from pathlib import Path + +import torch + +RUN_NAME = "enhancer_stage2" + +logger = logging.getLogger(__name__) + + +def get_source_url(relpath): + return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true" + + +def get_target_path(relpath: str | Path, run_dir: str | Path | None = None): + if run_dir is None: + run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME + return Path(run_dir) / relpath + + +def download(run_dir: str | Path | None = None): + relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"] + for relpath in relpaths: + path = get_target_path(relpath, run_dir=run_dir) + if path.exists(): + continue + url = get_source_url(relpath) + path.parent.mkdir(parents=True, exist_ok=True) + torch.hub.download_url_to_file(url, str(path)) + return get_target_path("", run_dir=run_dir) diff --git a/modules/repos_static/resemble_enhance/enhancer/enhancer.py b/modules/repos_static/resemble_enhance/enhancer/enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..84cda8b0ad3cb0d99060d27d13638bd5dae2098c --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/enhancer.py @@ -0,0 +1,185 @@ +import logging + +import matplotlib.pyplot as plt +import pandas as pd +import torch +from torch import Tensor, nn +from torch.distributions import Beta + +from ..common import Normalizer +from ..denoiser.inference import load_denoiser +from ..melspec import MelSpectrogram +from .hparams import HParams +from .lcfm import CFM, IRMAE, LCFM +from .univnet import UnivNet + +logger = logging.getLogger(__name__) + + +def _maybe(fn): + def _fn(*args): + if args[0] is None: + return None + return fn(*args) + + return _fn + + +def _normalize_wav(x: Tensor): + return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7) + + +class Enhancer(nn.Module): + def __init__(self, hp: HParams): + super().__init__() + self.hp = hp + + n_mels = self.hp.num_mels + vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim + latent_dim = self.hp.lcfm_latent_dim + + self.lcfm = LCFM( + IRMAE( + input_dim=n_mels, + output_dim=vocoder_input_dim, + latent_dim=latent_dim, + ), + CFM( + cond_dim=n_mels, + output_dim=self.hp.lcfm_latent_dim, + solver_nfe=self.hp.cfm_solver_nfe, + solver_method=self.hp.cfm_solver_method, + time_mapping_divisor=self.hp.cfm_time_mapping_divisor, + ), + z_scale=self.hp.lcfm_z_scale, + ) + + self.lcfm.set_mode_(self.hp.lcfm_training_mode) + + self.mel_fn = MelSpectrogram(hp) + self.vocoder = UnivNet(self.hp, vocoder_input_dim) + self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu") + self.normalizer = Normalizer() + + self._eval_lambd = 0.0 + + self.dummy: Tensor + self.register_buffer("dummy", torch.zeros(1)) + + if self.hp.enhancer_stage1_run_dir is not None: + pretrained_path = ( + self.hp.enhancer_stage1_run_dir + / "ds/G/default/mp_rank_00_model_states.pt" + ) + self._load_pretrained(pretrained_path) + + logger.info(f"{self.__class__.__name__} summary") + logger.info(f"{self.summarize()}") + + def _load_pretrained(self, path): + # Clone is necessary as otherwise it holds a reference to the original model + cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()} + denoiser_state_dict = { + k: v.clone() for k, v in self.denoiser.state_dict().items() + } + state_dict = torch.load(path, map_location="cpu")["module"] + self.load_state_dict(state_dict, strict=False) + self.lcfm.cfm.load_state_dict(cfm_state_dict) # Reset cfm + self.denoiser.load_state_dict(denoiser_state_dict) # Reset denoiser + logger.info(f"Loaded pretrained model from {path}") + + def summarize(self): + npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad) + npa = lambda m: sum(p.numel() for p in m.parameters()) + rows = [] + for name, module in self.named_children(): + rows.append(dict(name=name, trainable=npa_train(module), total=npa(module))) + rows.append(dict(name="total", trainable=npa_train(self), total=npa(self))) + df = pd.DataFrame(rows) + return df.to_markdown(index=False) + + def to_mel(self, x: Tensor, drop_last=True): + """ + Args: + x: (b t), wavs + Returns: + o: (b c t), mels + """ + if drop_last: + return self.mel_fn(x)[..., :-1] # (b d t) + return self.mel_fn(x) + + def _may_denoise(self, x: Tensor, y: Tensor | None = None): + if self.hp.lcfm_training_mode == "cfm": + return self.denoiser(x, y) + return x + + def configurate_(self, nfe, solver, lambd, tau): + """ + Args: + nfe: number of function evaluations + solver: solver method + lambd: denoiser strength [0, 1] + tau: prior temperature [0, 1] + """ + self.lcfm.cfm.solver.configurate_(nfe, solver) + self.lcfm.eval_tau_(tau) + self._eval_lambd = lambd + + def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None): + """ + Args: + x: (b t), mix wavs (fg + bg) + y: (b t), fg clean wavs + z: (b t), fg distorted wavs + Returns: + o: (b t), reconstructed wavs + """ + assert x.dim() == 2, f"Expected (b t), got {x.size()}" + assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}" + + if self.hp.lcfm_training_mode == "cfm": + self.normalizer.eval() + + x = _normalize_wav(x) + y = _maybe(_normalize_wav)(y) + z = _maybe(_normalize_wav)(z) + + x_mel_original = self.normalizer(self.to_mel(x), update=False) # (b d t) + + if self.hp.lcfm_training_mode == "cfm": + if self.training: + lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device) + lambd = lambd[:, None, None] + x_mel_denoised = self.normalizer( + self.to_mel(self._may_denoise(x, z)), update=False + ) + x_mel_denoised = x_mel_denoised.detach() + x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original + self._visualize(x_mel_original, x_mel_denoised) + else: + lambd = self._eval_lambd + if lambd == 0: + x_mel_denoised = x_mel_original + else: + x_mel_denoised = self.normalizer( + self.to_mel(self._may_denoise(x, z)), update=False + ) + x_mel_denoised = x_mel_denoised.detach() + x_mel_denoised = ( + lambd * x_mel_denoised + (1 - lambd) * x_mel_original + ) + else: + x_mel_denoised = x_mel_original + + y_mel = _maybe(self.to_mel)(y) # (b d t) + y_mel = _maybe(self.normalizer)(y_mel) + + lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t) + + if lcfm_decoded is None: + o = None + else: + o = self.vocoder(lcfm_decoded, y) + + return o diff --git a/modules/repos_static/resemble_enhance/enhancer/hparams.py b/modules/repos_static/resemble_enhance/enhancer/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..ca89bea6f5d7d4ec4f543f8bde88b29dcae69f6a --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/hparams.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from pathlib import Path + +from ..hparams import HParams as HParamsBase + + +@dataclass(frozen=True) +class HParams(HParamsBase): + cfm_solver_method: str = "midpoint" + cfm_solver_nfe: int = 64 + cfm_time_mapping_divisor: int = 4 + univnet_nc: int = 96 + + lcfm_latent_dim: int = 64 + lcfm_training_mode: str = "ae" + lcfm_z_scale: float = 5 + + vocoder_extra_dim: int = 32 + + gan_training_start_step: int | None = 5_000 + enhancer_stage1_run_dir: Path | None = None + + denoiser_run_dir: Path | None = None diff --git a/modules/repos_static/resemble_enhance/enhancer/inference.py b/modules/repos_static/resemble_enhance/enhancer/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..af57a2c7d3e5cc7b08b00f85f0135e881e50fcbe --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/inference.py @@ -0,0 +1,48 @@ +import logging +from functools import cache +from pathlib import Path + +import torch + +from ..inference import inference +from .download import download +from .hparams import HParams +from .enhancer import Enhancer + +logger = logging.getLogger(__name__) + + +@cache +def load_enhancer(run_dir: str | Path | None, device): + run_dir = download(run_dir) + hp = HParams.load(run_dir) + enhancer = Enhancer(hp) + path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt" + state_dict = torch.load(path, map_location="cpu")["module"] + enhancer.load_state_dict(state_dict) + enhancer.eval() + enhancer.to(device) + return enhancer + + +@torch.inference_mode() +def denoise(dwav, sr, device, run_dir=None): + enhancer = load_enhancer(run_dir, device) + return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device) + + +@torch.inference_mode() +def enhance( + dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None +): + assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}" + assert solver in ( + "midpoint", + "rk4", + "euler", + ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}" + assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}" + assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}" + enhancer = load_enhancer(run_dir, device) + enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau) + return inference(model=enhancer, dwav=dwav, sr=sr, device=device) diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9eca51c6bc6b2132389ac7ec0380159169a69499 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py @@ -0,0 +1,2 @@ +from .irmae import IRMAE +from .lcfm import CFM, LCFM diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py new file mode 100644 index 0000000000000000000000000000000000000000..a5125267b7f32e11c58e4b96bffa3ba1e96fdc4f --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py @@ -0,0 +1,372 @@ +import logging +from dataclasses import dataclass +from functools import partial +from typing import Protocol + +import matplotlib.pyplot as plt +import numpy as np +import scipy +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from tqdm import trange + +from .wn import WN + +logger = logging.getLogger(__name__) + + +class VelocityField(Protocol): + def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: + ... + + +class Solver: + def __init__( + self, + method="midpoint", + nfe=32, + viz_name="solver", + viz_every=100, + mel_fn=None, + time_mapping_divisor=4, + verbose=False, + ): + self.configurate_(nfe=nfe, method=method) + + self.verbose = verbose + self.viz_every = viz_every + self.viz_name = viz_name + + self._camera = None + self._mel_fn = mel_fn + self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor) + + def configurate_(self, nfe=None, method=None): + if nfe is None: + nfe = self.nfe + + if method is None: + method = self.method + + if nfe == 1 and method in ("midpoint", "rk4"): + logger.warning(f"1 NFE is not supported for {method}, using euler method instead.") + method = "euler" + + self.nfe = nfe + self.method = method + + @property + def time_mapping(self): + return self._time_mapping + + @staticmethod + def exponential_decay_mapping(t, n=4): + """ + Args: + n: target step + """ + + def h(t, a): + return (a**t - 1) / (a - 1) + + # Solve h(1/n) = 0.5 + a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0)) + + t = h(t, a=a) + + return t + + @torch.no_grad() + def _maybe_camera_snap(self, *, ψt, t): + camera = self._camera + if camera is not None: + if ψt.shape[1] == 1: + # Waveform, b 1 t, plot every 100 samples + plt.subplot(211) + plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue") + if self._mel_fn is not None: + plt.subplot(212) + mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0]) + plt.imshow(mel, origin="lower", interpolation="none") + elif ψt.shape[1] == 2: + # Complex + plt.subplot(121) + plt.imshow( + ψt.detach().cpu().numpy()[0, 0], + origin="lower", + interpolation="none", + ) + plt.subplot(122) + plt.imshow( + ψt.detach().cpu().numpy()[0, 1], + origin="lower", + interpolation="none", + ) + else: + # Spectrogram, b c t + plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none") + ax = plt.gca() + ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center") + camera.snap() + + @staticmethod + def _euler_step(t, ψt, dt, f: VelocityField): + return ψt + dt * f(t=t, ψt=ψt, dt=dt) + + @staticmethod + def _midpoint_step(t, ψt, dt, f: VelocityField): + return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt) + + @staticmethod + def _rk4_step(t, ψt, dt, f: VelocityField): + k1 = f(t=t, ψt=ψt, dt=dt) + k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt) + k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt) + k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt) + return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6 + + @property + def _step(self): + if self.method == "euler": + return self._euler_step + elif self.method == "midpoint": + return self._midpoint_step + elif self.method == "rk4": + return self._rk4_step + else: + raise ValueError(f"Unknown method: {self.method}") + + def get_running_train_loop(self): + try: + # Lazy import + from ...utils.train_loop import TrainLoop + + return TrainLoop.get_running_loop() + except ImportError: + return None + + @property + def visualizing(self): + loop = self.get_running_train_loop() + if loop is None: + return + out_path = loop.make_current_step_viz_path(self.viz_name, ".gif") + return loop.global_step % self.viz_every == 0 and not out_path.exists() + + def _reset_camera(self): + try: + from celluloid import Camera + + self._camera = Camera(plt.figure()) + except: + pass + + def _maybe_dump_camera(self): + camera = self._camera + loop = self.get_running_train_loop() + if camera is not None and loop is not None: + animation = camera.animate() + out_path = loop.make_current_step_viz_path(self.viz_name, ".gif") + out_path.parent.mkdir(exist_ok=True, parents=True) + animation.save(out_path, writer="pillow", fps=4) + plt.close() + self._camera = None + + @property + def n_steps(self): + n = self.nfe + if self.method == "euler": + pass + elif self.method == "midpoint": + n //= 2 + elif self.method == "rk4": + n //= 4 + else: + raise ValueError(f"Unknown method: {self.method}") + return n + + def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0): + ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1)) + + if self.visualizing: + self._reset_camera() + + if self.verbose: + steps = trange(self.n_steps, desc="CFM inference") + else: + steps = range(self.n_steps) + + ψt = ψ0 + + for i in steps: + dt = ts[i + 1] - ts[i] + t = ts[i] + self._maybe_camera_snap(ψt=ψt, t=t) + ψt = self._step(t=t, ψt=ψt, dt=dt, f=f) + + self._maybe_camera_snap(ψt=ψt, t=ts[-1]) + + ψ1 = ψt + del ψt + + self._maybe_dump_camera() + + return ψ1 + + def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0): + return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1) + + +class SinusodialTimeEmbedding(nn.Module): + def __init__(self, d_embed): + super().__init__() + self.d_embed = d_embed + assert d_embed % 2 == 0 + + def forward(self, t): + t = t.unsqueeze(-1) # ... 1 + p = torch.linspace(0, 4, self.d_embed // 2).to(t) + while p.dim() < t.dim(): + p = p.unsqueeze(0) # ... d/2 + sin = torch.sin(t * 10**p) + cos = torch.cos(t * 10**p) + return torch.cat([sin, cos], dim=-1) + + +@dataclass(eq=False) +class CFM(nn.Module): + """ + This mixin is for general diffusion models. + + ψ0 stands for the gaussian noise, and ψ1 is the data point. + + Here we follow the CFM style: + The generation process (reverse process) is from t=0 to t=1. + The forward process is from t=1 to t=0. + """ + + cond_dim: int + output_dim: int + time_emb_dim: int = 128 + viz_name: str = "cfm" + solver_nfe: int = 32 + solver_method: str = "midpoint" + time_mapping_divisor: int = 4 + + def __post_init__(self): + super().__init__() + self.solver = Solver( + viz_name=self.viz_name, + viz_every=1, + nfe=self.solver_nfe, + method=self.solver_method, + time_mapping_divisor=self.time_mapping_divisor, + ) + self.emb = SinusodialTimeEmbedding(self.time_emb_dim) + self.net = WN( + input_dim=self.output_dim, + output_dim=self.output_dim, + local_dim=self.cond_dim, + global_dim=self.time_emb_dim, + ) + + def _perturb(self, ψ1: Tensor, t: Tensor | None = None): + """ + Perturb ψ1 to ψt. + """ + raise NotImplementedError + + def _sample_ψ0(self, x: Tensor): + """ + Args: + x: (b c t), which implies the shape of ψ0 + """ + shape = list(x.shape) + shape[1] = self.output_dim + if self.training: + g = None + else: + g = torch.Generator(device=x.device) + g.manual_seed(0) # deterministic sampling during eval + ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g) + return ψ0 + + @property + def sigma(self): + return 1e-4 + + def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor): + """ + Eq (22) + """ + while t.dim() < ψ1.dim(): + t = t.unsqueeze(-1) + μ = t * ψ1 + (1 - t) * ψ0 + return μ + torch.randn_like(μ) * self.sigma + + def _to_u(self, *, ψ1, ψ0: Tensor): + """ + Eq (21) + """ + return ψ1 - ψ0 + + def _to_v(self, *, ψt, x, t: float | Tensor): + """ + Args: + ψt: (b c t) + x: (b c t) + t: (b) + Returns: + v: (b c t) + """ + if isinstance(t, (float, int)): + t = torch.full(ψt.shape[:1], t).to(ψt) + t = t.clamp(0, 1) # [0, 1) + g = self.emb(t) # (b d) + v = self.net(ψt, l=x, g=g) + return v + + def compute_losses(self, x, y, ψ0) -> dict: + """ + Args: + x: (b c t) + y: (b c t) + Returns: + losses: dict + """ + t = torch.rand(len(x), device=x.device, dtype=x.dtype) + t = self.solver.time_mapping(t) + + if ψ0 is None: + ψ0 = self._sample_ψ0(x) + + ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0) + + v = self._to_v(ψt=ψt, t=t, x=x) + u = self._to_u(ψ1=y, ψ0=ψ0) + + losses = dict(l1=F.l1_loss(v, u)) + + return losses + + @torch.inference_mode() + def sample(self, x, ψ0=None, t0=0.0): + """ + Args: + x: (b c t) + Returns: + y: (b ... t) + """ + if ψ0 is None: + ψ0 = self._sample_ψ0(x) + f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x) + ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0) + return ψ1 + + def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0): + if y is None: + y = self.sample(x, ψ0=ψ0, t0=t0) + else: + self.losses = self.compute_losses(x, y, ψ0=ψ0) + return y diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py new file mode 100644 index 0000000000000000000000000000000000000000..e71ab0cd8b9f07c3c27ca3877ee79b6510445d1f --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py @@ -0,0 +1,123 @@ +import logging +from dataclasses import dataclass + +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.utils.parametrizations import weight_norm + +from ...common import Normalizer + +logger = logging.getLogger(__name__) + + +@dataclass +class IRMAEOutput: + latent: Tensor # latent vector + decoded: Tensor | None # decoder output, include extra dim + + +class ResBlock(nn.Sequential): + def __init__(self, channels, dilations=[1, 2, 4, 8]): + wn = weight_norm + super().__init__( + nn.GroupNorm(32, channels), + nn.GELU(), + wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])), + nn.GroupNorm(32, channels), + nn.GELU(), + wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])), + nn.GroupNorm(32, channels), + nn.GELU(), + wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])), + nn.GroupNorm(32, channels), + nn.GELU(), + wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])), + ) + + def forward(self, x: Tensor): + return x + super().forward(x) + + +class IRMAE(nn.Module): + def __init__( + self, + input_dim, + output_dim, + latent_dim, + hidden_dim=1024, + num_irms=4, + ): + """ + Args: + input_dim: input dimension + output_dim: output dimension + latent_dim: latent dimension + hidden_dim: hidden layer dimension + num_irm_matrics: number of implicit rank minimization matrices + norm: normalization layer + """ + self.input_dim = input_dim + super().__init__() + + self.encoder = nn.Sequential( + nn.Conv1d(input_dim, hidden_dim, 3, padding="same"), + *[ResBlock(hidden_dim) for _ in range(4)], + # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf) + *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)], + nn.Tanh(), + ) + + self.decoder = nn.Sequential( + nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"), + *[ResBlock(hidden_dim) for _ in range(4)], + nn.Conv1d(hidden_dim, output_dim, 1), + ) + + self.head = nn.Sequential( + nn.Conv1d(output_dim, hidden_dim, 3, padding="same"), + nn.GELU(), + nn.Conv1d(hidden_dim, input_dim, 1), + ) + + self.estimator = Normalizer() + + def encode(self, x): + """ + Args: + x: (b c t) tensor + """ + z = self.encoder(x) # (b c t) + _ = self.estimator(z) # Estimate the glboal mean and std of z + self.stats = {} + self.stats["z_mean"] = z.mean().item() + self.stats["z_std"] = z.std().item() + self.stats["z_abs_68"] = z.abs().quantile(0.6827).item() + self.stats["z_abs_95"] = z.abs().quantile(0.9545).item() + self.stats["z_abs_99"] = z.abs().quantile(0.9973).item() + return z + + def decode(self, z): + """ + Args: + z: (b c t) tensor + """ + return self.decoder(z) + + def forward(self, x, skip_decoding=False): + """ + Args: + x: (b c t) tensor + skip_decoding: if True, skip the decoding step + """ + z = self.encode(x) # q(z|x) + + if skip_decoding: + # This speeds up the training in cfm only mode + decoded = None + else: + decoded = self.decode(z) # p(x|z) + predicted = self.head(decoded) + self.losses = dict(mse=F.mse_loss(predicted, x)) + + return IRMAEOutput(latent=z, decoded=decoded) diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2f5f88718e2f42f82e2f4714ea510b4677b450 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py @@ -0,0 +1,152 @@ +import logging +from enum import Enum + +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from torch import Tensor, nn + +from .cfm import CFM +from .irmae import IRMAE, IRMAEOutput + +logger = logging.getLogger(__name__) + + +def freeze_(module): + for p in module.parameters(): + p.requires_grad_(False) + + +class LCFM(nn.Module): + class Mode(Enum): + AE = "ae" + CFM = "cfm" + + def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0): + super().__init__() + self.ae = ae + self.cfm = cfm + self.z_scale = z_scale + self._mode = None + self._eval_tau = 0.5 + + @property + def mode(self): + return self._mode + + def set_mode_(self, mode): + mode = self.Mode(mode) + self._mode = mode + + if mode == mode.AE: + freeze_(self.cfm) + logger.info("Freeze cfm") + elif mode == mode.CFM: + freeze_(self.ae) + logger.info("Freeze ae (encoder and decoder)") + else: + raise ValueError(f"Unknown training mode: {mode}") + + def get_running_train_loop(self): + try: + # Lazy import + from ...utils.train_loop import TrainLoop + + return TrainLoop.get_running_loop() + except ImportError: + return None + + @property + def global_step(self): + loop = self.get_running_train_loop() + if loop is None: + return None + return loop.global_step + + @torch.no_grad() + def _visualize(self, x, y, y_): + loop = self.get_running_train_loop() + if loop is None: + return + + plt.subplot(221) + plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.title("GT") + + plt.subplot(222) + y_ = y_[:, : y.shape[1]] + plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.title("Posterior") + + plt.subplot(223) + z_ = self.cfm(x) + y__ = self.ae.decode(z_) + y__ = y__[:, : y.shape[1]] + plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.title("C-Prior") + del y__ + + plt.subplot(224) + z_ = torch.randn_like(z_) + y__ = self.ae.decode(z_) + y__ = y__[:, : y.shape[1]] + plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none") + plt.title("Prior") + del z_, y__ + + path = loop.make_current_step_viz_path("recon", ".png") + path.parent.mkdir(exist_ok=True, parents=True) + plt.tight_layout() + plt.savefig(path, dpi=500) + plt.close() + + def _scale(self, z: Tensor): + return z * self.z_scale + + def _unscale(self, z: Tensor): + return z / self.z_scale + + def eval_tau_(self, tau): + self._eval_tau = tau + + def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None): + """ + Args: + x: (b d t), condition mel + y: (b d t), target mel + ψ0: (b d t), starting mel + """ + if self.mode == self.Mode.CFM: + self.ae.eval() # Always set to eval when training cfm + + if ψ0 is not None: + ψ0 = self._scale(self.ae.encode(ψ0)) + if self.training: + tau = torch.rand_like(ψ0[:, :1, :1]) + else: + tau = self._eval_tau + ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0 + + if y is None: + if self.mode == self.Mode.AE: + with torch.no_grad(): + training = self.ae.training + self.ae.eval() + z = self.ae.encode(x) + self.ae.train(training) + else: + z = self._unscale(self.cfm(x, ψ0=ψ0)) + + h = self.ae.decode(z) + else: + ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM) + + if self.mode == self.Mode.CFM: + _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0) + + h = ae_output.decoded + + if h is not None and self.global_step is not None and self.global_step % 100 == 0: + self._visualize(x[:1], y[:1], h[:1]) + + return h diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py new file mode 100644 index 0000000000000000000000000000000000000000..8bde173c205bb74f30ed95a9f013b3eb5b2abe5a --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py @@ -0,0 +1,147 @@ +import logging +import math + +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +@torch.jit.script +def _fused_tanh_sigmoid(h): + a, b = h.chunk(2, dim=1) + h = a.tanh() * b.sigmoid() + return h + + +class WNLayer(nn.Module): + """ + A DiffWave-like WN + """ + + def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation): + super().__init__() + + local_output_dim = hidden_dim * 2 + + if global_dim is not None: + self.gconv = nn.Conv1d(global_dim, hidden_dim, 1) + + if local_dim is not None: + self.lconv = nn.Conv1d(local_dim, local_output_dim, 1) + + self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same") + + self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1) + + def forward(self, z, l, g): + identity = z + + if g is not None: + if g.dim() == 2: + g = g.unsqueeze(-1) + z = z + self.gconv(g) + + z = self.dconv(z) + + if l is not None: + z = z + self.lconv(l) + + z = _fused_tanh_sigmoid(z) + + h = self.out(z) + + z, s = h.chunk(2, dim=1) + + o = (z + identity) / math.sqrt(2) + + return o, s + + +class WN(nn.Module): + def __init__( + self, + input_dim, + output_dim, + local_dim=None, + global_dim=None, + n_layers=30, + kernel_size=3, + dilation_cycle=5, + hidden_dim=512, + ): + super().__init__() + assert kernel_size % 2 == 1 + assert hidden_dim % 2 == 0 + + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.local_dim = local_dim + self.global_dim = global_dim + + self.start = nn.Conv1d(input_dim, hidden_dim, 1) + if local_dim is not None: + self.local_norm = nn.InstanceNorm1d(local_dim) + + self.layers = nn.ModuleList( + [ + WNLayer( + hidden_dim=hidden_dim, + local_dim=local_dim, + global_dim=global_dim, + kernel_size=kernel_size, + dilation=2 ** (i % dilation_cycle), + ) + for i in range(n_layers) + ] + ) + + self.end = nn.Conv1d(hidden_dim, output_dim, 1) + + def forward(self, z, l=None, g=None): + """ + Args: + z: input (b c t) + l: local condition (b c t) + g: global condition (b d) + """ + z = self.start(z) + + if l is not None: + l = self.local_norm(l) + + # Skips + s_list = [] + + for layer in self.layers: + z, s = layer(z, l, g) + s_list.append(s) + + s_list = torch.stack(s_list, dim=0).sum(dim=0) + s_list = s_list / math.sqrt(len(self.layers)) + + o = self.end(s_list) + + return o + + def summarize(self, length=100): + from ptflops import get_model_complexity_info + + x = torch.randn(1, self.input_dim, length) + + macs, params = get_model_complexity_info( + self, + (self.input_dim, length), + as_strings=True, + print_per_layer_stat=True, + verbose=True, + ) + + print(f"Input shape: {x.shape}") + print(f"Computational complexity: {macs}") + print(f"Number of parameters: {params}") + + +if __name__ == "__main__": + model = WN(input_dim=64, output_dim=64) + model.summarize() diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py b/modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4d2fea066e2e71371c6af840e759f1676380170 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py @@ -0,0 +1 @@ +from .univnet import UnivNet diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..270596c8f44f9295026cf308b39151a08dbed85a --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py @@ -0,0 +1,5 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..5165557d7dcadcb4d07018e13562b22f8c85e91e --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc6e12a9dbaa9ac41bd349b7f1797442052e4f6 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/amp.py b/modules/repos_static/resemble_enhance/enhancer/univnet/amp.py new file mode 100644 index 0000000000000000000000000000000000000000..469026338771408a24253ae52c8f2f22a6057475 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/amp.py @@ -0,0 +1,101 @@ +# Refer from https://github.com/NVIDIA/BigVGAN + +import math + +import torch +import torch.nn as nn +from torch import nn +from torch.nn.utils.parametrizations import weight_norm + +from .alias_free_torch import DownSample1d, UpSample1d + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super().__init__() + self.in_features = in_features + self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha)) + self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha)) + self.clamp = clamp + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.log_alpha.exp().clamp(*self.clamp) + alpha = alpha[None, :, None] + + beta = self.log_beta.exp().clamp(*self.clamp) + beta = beta[None, :, None] + + x = x + (1.0 / beta) * (x * alpha).sin().pow(2) + + return x + + +class UpActDown(nn.Module): + def __init__( + self, + act, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = act + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + # x: [B,C,T] + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +class AMPBlock(nn.Sequential): + def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)): + super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations)) + + def _make_layer(self, channels, kernel_size, dilation): + return nn.Sequential( + weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")), + UpActDown(act=SnakeBeta(channels)), + weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")), + ) + + def forward(self, x): + return x + super().forward(x) diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py b/modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..ef3bd2552ea6f7f654c72737e079ce3239835d68 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py @@ -0,0 +1,210 @@ +import logging + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.utils.parametrizations import weight_norm + +from ..hparams import HParams +from .mrstft import get_stft_cfgs + +logger = logging.getLogger(__name__) + + +class PeriodNetwork(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + wn = weight_norm + self.convs = nn.ModuleList( + [ + wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))), + wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))), + wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))), + wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))), + wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + """ + Args: + x: [B, 1, T] + """ + assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}." + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.2) + x = self.conv_post(x) + x = torch.flatten(x, 1, -1) + + return x + + +class SpecNetwork(nn.Module): + def __init__(self, stft_cfg: dict): + super().__init__() + wn = weight_norm + self.stft_cfg = stft_cfg + self.convs = nn.ModuleList( + [ + wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), + wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), + ] + ) + self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) + + def forward(self, x): + """ + Args: + x: [B, 1, T] + """ + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.2) + x = self.conv_post(x) + x = x.flatten(1, -1) + return x + + def spectrogram(self, x): + """ + Args: + x: [B, 1, T] + """ + x = x.squeeze(1) + dtype = x.dtype + stft_cfg = dict(self.stft_cfg) + x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg) + mag = x.norm(p=2, dim=-1) # [B, F, TT] + mag = mag.to(dtype) # [B, F, TT] + return mag + + +class MD(nn.ModuleList): + def __init__(self, l: list): + super().__init__([self._create_network(x) for x in l]) + self._loss_type = None + + def loss_type_(self, loss_type): + self._loss_type = loss_type + + def _create_network(self, _): + raise NotImplementedError + + def _forward_each(self, d, x, y): + assert self._loss_type is not None, "loss_type is not set." + loss_type = self._loss_type + + if loss_type == "hinge": + if y == 0: + # d(x) should be small -> -1 + loss = F.relu(1 + d(x)).mean() + elif y == 1: + # d(x) should be large -> 1 + loss = F.relu(1 - d(x)).mean() + else: + raise ValueError(f"Invalid y: {y}") + elif loss_type == "wgan": + if y == 0: + loss = d(x).mean() + elif y == 1: + loss = -d(x).mean() + else: + raise ValueError(f"Invalid y: {y}") + else: + raise ValueError(f"Invalid loss_type: {loss_type}") + + return loss + + def forward(self, x, y) -> Tensor: + losses = [self._forward_each(d, x, y) for d in self] + return torch.stack(losses).mean() + + +class MPD(MD): + def __init__(self): + super().__init__([2, 3, 7, 13, 17]) + + def _create_network(self, period): + return PeriodNetwork(period) + + +class MRD(MD): + def __init__(self, stft_cfgs): + super().__init__(stft_cfgs) + + def _create_network(self, stft_cfg): + return SpecNetwork(stft_cfg) + + +class Discriminator(nn.Module): + @property + def wav_rate(self): + return self.hp.wav_rate + + def __init__(self, hp: HParams): + super().__init__() + self.hp = hp + self.stft_cfgs = get_stft_cfgs(hp) + self.mpd = MPD() + self.mrd = MRD(self.stft_cfgs) + self.dummy_float: Tensor + self.register_buffer("dummy_float", torch.zeros(0), persistent=False) + + def loss_type_(self, loss_type): + self.mpd.loss_type_(loss_type) + self.mrd.loss_type_(loss_type) + + def forward(self, fake, real=None): + """ + Args: + fake: [B T] + real: [B T] + """ + fake = fake.to(self.dummy_float) + + if real is None: + self.loss_type_("wgan") + else: + length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1] + assert length_difference < 0.05, f"length_difference should be smaller than 5%" + + self.loss_type_("hinge") + real = real.to(self.dummy_float) + + fake = fake[..., : real.shape[-1]] + real = real[..., : fake.shape[-1]] + + losses = {} + + assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}." + assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}." + + fake = fake.unsqueeze(1) + + if real is None: + losses["mpd"] = self.mpd(fake, 1) + losses["mrd"] = self.mrd(fake, 1) + else: + real = real.unsqueeze(1) + losses["mpd_fake"] = self.mpd(fake, 0) + losses["mpd_real"] = self.mpd(real, 1) + losses["mrd_fake"] = self.mrd(fake, 0) + losses["mrd_real"] = self.mrd(real, 1) + + return losses diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py b/modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py new file mode 100644 index 0000000000000000000000000000000000000000..da56619090206c45fece0bc2c70e8fd3d2513704 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py @@ -0,0 +1,281 @@ +""" refer from https://github.com/zceng/LVCNet """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.parametrizations import weight_norm + +from .amp import AMPBlock + + +class KernelPredictor(torch.nn.Module): + """Kernel predictor for the location-variable convolutions""" + + def __init__( + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + """ + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + """ + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + add_extra_noise=False, + downsampling=False, + ): + super().__init__() + + self.add_extra_noise = add_extra_noise + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, + ) + + if downsampling: + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")), + nn.AvgPool1d(stride, stride), + ) + else: + if stride == 1: + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + weight_norm(nn.Conv1d(in_channels, in_channels, 1)), + ) + else: + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + weight_norm( + nn.ConvTranspose1d( + in_channels, + in_channels, + 2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ) + ), + ) + + self.amp_block = AMPBlock(in_channels) + + self.conv_blocks = nn.ModuleList() + for d in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + + # Add one amp block just after the upsampling + x = self.amp_block(x) # (B, c_g, stride * L') + + kernels, bias = self.kernel_predictor(c) + + if self.add_extra_noise: + # Add extra noise to part of the feature + a, b = x.chunk(2, dim=1) + b = b + torch.randn_like(b) * 0.1 + x = torch.cat([a, b], dim=1) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution( + output, k, b, hop_size=self.cond_hop_length + ) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :] + ) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). + """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + + assert in_length == ( + kernel_length * hop_size + ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py b/modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py new file mode 100644 index 0000000000000000000000000000000000000000..ce95b43269c17ff05736bc338220e59345524309 --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + + +import torch +import torch.nn.functional as F +from torch import nn + +from ..hparams import HParams + + +def _make_stft_cfg(hop_length, win_length=None): + if win_length is None: + win_length = 4 * hop_length + n_fft = 2 ** (win_length - 1).bit_length() + return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length) + + +def get_stft_cfgs(hp: HParams): + assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}" + return [_make_stft_cfg(h) for h in (100, 256, 512)] + + +def stft(x, n_fft, hop_length, win_length, window): + dtype = x.dtype + x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True) + x = x.abs().to(dtype) + x = x.transpose(2, 1) # (b f t) -> (b t f) + return x + + +class SpectralConvergengeLoss(nn.Module): + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Spectral convergence loss value. + """ + return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + + +class LogSTFTMagnitudeLoss(nn.Module): + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Log STFT magnitude loss value. + """ + return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag)) + + +class STFTLoss(nn.Module): + def __init__(self, hp, stft_cfg: dict, window="hann_window"): + super().__init__() + self.hp = hp + self.stft_cfg = stft_cfg + self.spectral_convergenge_loss = SpectralConvergengeLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False) + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. + """ + stft_cfg = dict(self.stft_cfg) + x_mag = stft(x, **stft_cfg, window=self.window) # (b t) -> (b t f) + y_mag = stft(y, **stft_cfg, window=self.window) + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + return dict(sc=sc_loss, mag=mag_loss) + + +class MRSTFTLoss(nn.Module): + def __init__(self, hp: HParams, window="hann_window"): + """Initialize Multi resolution STFT loss module. + Args: + resolutions (list): List of (FFT size, hop size, window length). + window (str): Window function type. + """ + super().__init__() + stft_cfgs = get_stft_cfgs(hp) + self.stft_losses = nn.ModuleList() + self.hp = hp + for c in stft_cfgs: + self.stft_losses += [STFTLoss(hp, c, window=window)] + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (b t). + y (Tensor): Groundtruth signal (b t). + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + """ + assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}." + + dtype = x.dtype + + x = x.float() + y = y.float() + + # Align length + x = x[..., : y.shape[-1]] + y = y[..., : x.shape[-1]] + + losses = {} + + for f in self.stft_losses: + d = f(x, y) + for k, v in d.items(): + losses.setdefault(k, []).append(v) + + for k, v in losses.items(): + losses[k] = torch.stack(v, dim=0).mean().to(dtype) + + return losses diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py new file mode 100644 index 0000000000000000000000000000000000000000..bb20217f048f398236698f6a38927310d0c1ba9b --- /dev/null +++ b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py @@ -0,0 +1,94 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.utils.parametrizations import weight_norm + +from ..hparams import HParams +from .lvcnet import LVCBlock +from .mrstft import MRSTFTLoss + + +class UnivNet(nn.Module): + @property + def d_noise(self): + return 128 + + @property + def strides(self): + return [7, 5, 4, 3] + + @property + def dilations(self): + return [1, 3, 9, 27] + + @property + def nc(self): + return self.hp.univnet_nc + + @property + def scale_factor(self) -> int: + return self.hp.hop_size + + def __init__(self, hp: HParams, d_input): + super().__init__() + self.d_input = d_input + + self.hp = hp + + self.blocks = nn.ModuleList( + [ + LVCBlock( + self.nc, + d_input, + stride=stride, + dilations=self.dilations, + cond_hop_length=hop_length, + kpnet_conv_size=3, + ) + for stride, hop_length in zip(self.strides, np.cumprod(self.strides)) + ] + ) + + self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect")) + + self.conv_post = nn.Sequential( + nn.LeakyReLU(0.2), + weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")), + nn.Tanh(), + ) + + self.mrstft = MRSTFTLoss(hp) + + @property + def eps(self): + return 1e-5 + + def forward(self, x: Tensor, y: Tensor | None = None, npad=10): + """ + Args: + x: (b c t), acoustic features + y: (b t), waveform + Returns: + z: (b t), waveform + """ + assert x.ndim == 3, "x must be 3D tensor" + assert y is None or y.ndim == 2, "y must be 2D tensor" + assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}" + assert npad >= 0, "npad must be positive or zero" + + x = F.pad(x, (0, npad), "constant", 0) + z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x) + z = self.conv_pre(z) # (b c t) + + for block in self.blocks: + z = block(z, x) # (b c t) + + z = self.conv_post(z) # (b 1 t) + z = z[..., : -self.scale_factor * npad] + z = z.squeeze(1) # (b t) + + if y is not None: + self.losses = self.mrstft(z, y) + + return z diff --git a/modules/repos_static/resemble_enhance/hparams.py b/modules/repos_static/resemble_enhance/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e716175fa962ada1d98cd755430e2ea770278c --- /dev/null +++ b/modules/repos_static/resemble_enhance/hparams.py @@ -0,0 +1,128 @@ +import logging +from dataclasses import asdict, dataclass +from pathlib import Path + +from omegaconf import OmegaConf +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +logger = logging.getLogger(__name__) + +console = Console() + + +def _make_stft_cfg(hop_length, win_length=None): + if win_length is None: + win_length = 4 * hop_length + n_fft = 2 ** (win_length - 1).bit_length() + return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length) + + +def _build_rich_table(rows, columns, title=None): + table = Table(title=title, header_style=None) + for column in columns: + table.add_column(column.capitalize(), justify="left") + for row in rows: + table.add_row(*map(str, row)) + return Panel(table, expand=False) + + +def _rich_print_dict(d, title="Config", key="Key", value="Value"): + console.print(_build_rich_table(d.items(), [key, value], title)) + + +@dataclass(frozen=True) +class HParams: + # Dataset + fg_dir: Path = Path("data/fg") + bg_dir: Path = Path("data/bg") + rir_dir: Path = Path("data/rir") + load_fg_only: bool = False + praat_augment_prob: float = 0 + + # Audio settings + wav_rate: int = 44_100 + n_fft: int = 2048 + win_size: int = 2048 + hop_size: int = 420 # 9.5ms + num_mels: int = 128 + stft_magnitude_min: float = 1e-4 + preemphasis: float = 0.97 + mix_alpha_range: tuple[float, float] = (0.2, 0.8) + + # Training + nj: int = 64 + training_seconds: float = 1.0 + batch_size_per_gpu: int = 16 + min_lr: float = 1e-5 + max_lr: float = 1e-4 + warmup_steps: int = 1000 + max_steps: int = 1_000_000 + gradient_clipping: float = 1.0 + + @property + def deepspeed_config(self): + return { + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, + "optimizer": { + "type": "Adam", + "params": {"lr": float(self.min_lr)}, + }, + "scheduler": { + "type": "WarmupDecayLR", + "params": { + "warmup_min_lr": float(self.min_lr), + "warmup_max_lr": float(self.max_lr), + "warmup_num_steps": self.warmup_steps, + "total_num_steps": self.max_steps, + "warmup_type": "linear", + }, + }, + "gradient_clipping": self.gradient_clipping, + } + + @property + def stft_cfgs(self): + assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}" + return [_make_stft_cfg(h) for h in (100, 256, 512)] + + @classmethod + def from_yaml(cls, path: Path) -> "HParams": + logger.info(f"Reading hparams from {path}") + # First merge to fix types (e.g., str -> Path) + return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path)))) + + def save_if_not_exists(self, run_dir: Path): + path = run_dir / "hparams.yaml" + if path.exists(): + logger.info(f"{path} already exists, not saving") + return + path.parent.mkdir(parents=True, exist_ok=True) + OmegaConf.save(asdict(self), str(path)) + + @classmethod + def load(cls, run_dir, yaml: Path | None = None): + hps = [] + + if (run_dir / "hparams.yaml").exists(): + hps.append(cls.from_yaml(run_dir / "hparams.yaml")) + + if yaml is not None: + hps.append(cls.from_yaml(yaml)) + + if len(hps) == 0: + hps.append(cls()) + + for hp in hps[1:]: + if hp != hps[0]: + errors = {} + for k, v in asdict(hp).items(): + if getattr(hps[0], k) != v: + errors[k] = f"{getattr(hps[0], k)} != {v}" + raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}") + + return hps[0] + + def print(self): + _rich_print_dict(asdict(self), title="HParams") diff --git a/modules/repos_static/resemble_enhance/inference.py b/modules/repos_static/resemble_enhance/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..6e78a11fdf134bcc182e5c9ef0cf81e02c64850b --- /dev/null +++ b/modules/repos_static/resemble_enhance/inference.py @@ -0,0 +1,163 @@ +import logging +import time + +import torch +import torch.nn.functional as F +from torch.nn.utils.parametrize import remove_parametrizations +from torchaudio.functional import resample +from torchaudio.transforms import MelSpectrogram +from tqdm import trange + +from .hparams import HParams + +logger = logging.getLogger(__name__) + + +@torch.inference_mode() +def inference_chunk(model, dwav, sr, device, npad=441): + assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz" + del sr + + length = dwav.shape[-1] + abs_max = dwav.abs().max().clamp(min=1e-7) + + assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D" + dwav = dwav.to(device) + dwav = dwav / abs_max # Normalize + dwav = F.pad(dwav, (0, npad)) + hwav = model(dwav[None])[0].cpu() # (T,) + hwav = hwav[:length] # Trim padding + hwav = hwav * abs_max # Unnormalize + + return hwav + + +def compute_corr(x, y): + return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs() + + +def compute_offset(chunk1, chunk2, sr=44100): + """ + Args: + chunk1: (T,) + chunk2: (T,) + Returns: + offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset) + """ + hop_length = sr // 200 # 5 ms resolution + win_length = hop_length * 4 + n_fft = 2 ** (win_length - 1).bit_length() + + mel_fn = MelSpectrogram( + sample_rate=sr, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=80, + f_min=0.0, + f_max=sr // 2, + ) + + spec1 = mel_fn(chunk1).log1p() + spec2 = mel_fn(chunk2).log1p() + + corr = compute_corr(spec1, spec2) # (F, T) + corr = corr.mean(dim=0) # (T,) + + argmax = corr.argmax().item() + + if argmax > len(corr) // 2: + argmax -= len(corr) + + offset = -argmax * hop_length + + return offset + + +def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None): + signal_length = (len(chunks) - 1) * hop_length + chunk_length + overlap_length = chunk_length - hop_length + signal = torch.zeros(signal_length, device=chunks[0].device) + + fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device) + fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)]) + fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device) + fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout]) + + for i, chunk in enumerate(chunks): + start = i * hop_length + end = start + chunk_length + + if len(chunk) < chunk_length: + chunk = F.pad(chunk, (0, chunk_length - len(chunk))) + + if i > 0: + pre_region = chunks[i - 1][-overlap_length:] + cur_region = chunk[:overlap_length] + offset = compute_offset(pre_region, cur_region, sr=sr) + start -= offset + end -= offset + + if i == 0: + chunk = chunk * fadeout + elif i == len(chunks) - 1: + chunk = chunk * fadein + else: + chunk = chunk * fadein * fadeout + + signal[start:end] += chunk[: len(signal[start:end])] + + signal = signal[:length] + + return signal + + +def remove_weight_norm_recursively(module): + for _, module in module.named_modules(): + try: + remove_parametrizations(module, "weight") + except Exception: + pass + + +def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0): + remove_weight_norm_recursively(model) + + hp: HParams = model.hp + + dwav = resample( + dwav, + orig_freq=sr, + new_freq=hp.wav_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method="sinc_interp_kaiser", + beta=14.769656459379492, + ) + + del sr # Everything is in hp.wav_rate now + + sr = hp.wav_rate + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start_time = time.perf_counter() + + chunk_length = int(sr * chunk_seconds) + overlap_length = int(sr * overlap_seconds) + hop_length = chunk_length - overlap_length + + chunks = [] + for start in trange(0, dwav.shape[-1], hop_length): + chunks.append(inference_chunk(model, dwav[start : start + chunk_length], sr, device)) + + hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1]) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + elapsed_time = time.perf_counter() - start_time + logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz") + + return hwav, sr diff --git a/modules/repos_static/resemble_enhance/melspec.py b/modules/repos_static/resemble_enhance/melspec.py new file mode 100644 index 0000000000000000000000000000000000000000..dce1f8bfb95b9a1814db8c7305c07ccf2bfa9111 --- /dev/null +++ b/modules/repos_static/resemble_enhance/melspec.py @@ -0,0 +1,61 @@ +import numpy as np +import torch +from torch import nn +from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram + +from .hparams import HParams + + +class MelSpectrogram(nn.Module): + def __init__(self, hp: HParams): + """ + Torch implementation of Resemble's mel extraction. + Note that the values are NOT identical to librosa's implementation + due to floating point precisions. + """ + super().__init__() + self.hp = hp + self.melspec = TorchMelSpectrogram( + hp.wav_rate, + n_fft=hp.n_fft, + win_length=hp.win_size, + hop_length=hp.hop_size, + f_min=0, + f_max=hp.wav_rate // 2, + n_mels=hp.num_mels, + power=1, + normalized=False, + # NOTE: Folowing librosa's default. + pad_mode="constant", + norm="slaney", + mel_scale="slaney", + ) + self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min])) + self.min_level_db = 20 * np.log10(hp.stft_magnitude_min) + self.preemphasis = hp.preemphasis + self.hop_size = hp.hop_size + + def forward(self, wav, pad=True): + """ + Args: + wav: [B, T] + """ + device = wav.device + if wav.is_mps: + wav = wav.cpu() + self.to(wav.device) + if self.preemphasis > 0: + wav = torch.nn.functional.pad(wav, [1, 0], value=0) + wav = wav[..., 1:] - self.preemphasis * wav[..., :-1] + mel = self.melspec(wav) + mel = self._amp_to_db(mel) + mel_normed = self._normalize(mel) + assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size # Sanity check + mel_normed = mel_normed.to(device) + return mel_normed # (M, T) + + def _normalize(self, s, headroom_db=15): + return (s - self.min_level_db) / (-self.min_level_db + headroom_db) + + def _amp_to_db(self, x): + return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20 diff --git a/modules/repos_static/resemble_enhance/utils/__init__.py b/modules/repos_static/resemble_enhance/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07f6303742506be443f477c40a42a9551b4e8af4 --- /dev/null +++ b/modules/repos_static/resemble_enhance/utils/__init__.py @@ -0,0 +1,2 @@ +from .logging import setup_logging +from .utils import save_mels, tree_map diff --git a/modules/repos_static/resemble_enhance/utils/control.py b/modules/repos_static/resemble_enhance/utils/control.py new file mode 100644 index 0000000000000000000000000000000000000000..56b74b46d73b0c3757849dad310ca0899bb5f5a4 --- /dev/null +++ b/modules/repos_static/resemble_enhance/utils/control.py @@ -0,0 +1,26 @@ +import logging +import selectors +import sys +from functools import cache + +from .distributed import global_leader_only + +_logger = logging.getLogger(__name__) + + +@cache +def _get_stdin_selector(): + selector = selectors.DefaultSelector() + selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ) + return selector + + +@global_leader_only(boardcast_return=True) +def non_blocking_input(): + s = "" + selector = _get_stdin_selector() + events = selector.select(timeout=0) + for key, _ in events: + s: str = key.fileobj.readline().strip() + _logger.info(f'Get stdin "{s}".') + return s diff --git a/modules/repos_static/resemble_enhance/utils/logging.py b/modules/repos_static/resemble_enhance/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..26c43b6dc785ff6547478cb04833dd92b5df7311 --- /dev/null +++ b/modules/repos_static/resemble_enhance/utils/logging.py @@ -0,0 +1,38 @@ +import logging +from pathlib import Path + +from rich.logging import RichHandler + +from .distributed import global_leader_only + + +@global_leader_only +def setup_logging(run_dir): + handlers = [] + stdout_handler = RichHandler() + stdout_handler.setLevel(logging.INFO) + handlers.append(stdout_handler) + + if run_dir is not None: + filename = Path(run_dir) / f"log.txt" + filename.parent.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(filename, mode="a") + file_handler.setLevel(logging.DEBUG) + handlers.append(file_handler) + + # Update all existing loggers + for name in ["DeepSpeed"]: + logger = logging.getLogger(name) + if isinstance(logger, logging.Logger): + for handler in list(logger.handlers): + logger.removeHandler(handler) + for handler in handlers: + logger.addHandler(handler) + + # Set the default logger + logging.basicConfig( + level=logging.getLevelName("INFO"), + format="%(message)s", + datefmt="[%X]", + handlers=handlers, + ) diff --git a/modules/repos_static/resemble_enhance/utils/utils.py b/modules/repos_static/resemble_enhance/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c402c9ae2bd634e903d2a9861243005e6a8c9147 --- /dev/null +++ b/modules/repos_static/resemble_enhance/utils/utils.py @@ -0,0 +1,73 @@ +from typing import Callable, TypeVar, overload + +import matplotlib.pyplot as plt +import numpy as np + + +def save_mels(path, *, targ_mel, pred_mel, cond_mel): + n = 3 if cond_mel is None else 4 + + plt.figure(figsize=(10, n * 4)) + + i = 1 + + plt.subplot(n, 1, i) + plt.imshow(pred_mel, origin="lower", interpolation="none") + plt.title(f"Pred mel {pred_mel.shape}") + i += 1 + + plt.subplot(n, 1, i) + plt.imshow(targ_mel, origin="lower", interpolation="none") + plt.title(f"GT mel {targ_mel.shape}") + i += 1 + + plt.subplot(n, 1, i) + pred_mel = pred_mel[:, : targ_mel.shape[1]] + targ_mel = targ_mel[:, : pred_mel.shape[1]] + plt.imshow(np.abs(pred_mel - targ_mel), origin="lower", interpolation="none") + plt.title(f"Diff mel {pred_mel.shape}, mse={np.mean((pred_mel - targ_mel)**2):.4f}") + i += 1 + + if cond_mel is not None: + plt.subplot(n, 1, i) + plt.imshow(cond_mel, origin="lower", interpolation="none") + plt.title(f"Cond mel {cond_mel.shape}") + i += 1 + + plt.savefig(path, dpi=480) + plt.close() + + +T = TypeVar("T") + + +@overload +def tree_map(fn: Callable, x: list[T]) -> list[T]: + ... + + +@overload +def tree_map(fn: Callable, x: tuple[T]) -> tuple[T]: + ... + + +@overload +def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]: + ... + + +@overload +def tree_map(fn: Callable, x: T) -> T: + ... + + +def tree_map(fn: Callable, x): + if isinstance(x, list): + x = [tree_map(fn, xi) for xi in x] + elif isinstance(x, tuple): + x = (tree_map(fn, xi) for xi in x) + elif isinstance(x, dict): + x = {k: tree_map(fn, v) for k, v in x.items()} + else: + x = fn(x) + return x diff --git a/modules/speaker.py b/modules/speaker.py index d066f2b20d3cd0d54331eba6c7b905db695bc794..2fcbc5a4ac99a89f382fd4f9988757d4c8f71470 100644 --- a/modules/speaker.py +++ b/modules/speaker.py @@ -99,6 +99,10 @@ class SpeakerManager: self.speakers[speaker_file] = Speaker.from_file( self.speaker_dir + speaker_file ) + # 检查是否有被删除的,同步到 speakers + for fname, spk in self.speakers.items(): + if not os.path.exists(self.speaker_dir + fname): + del self.speakers[fname] def list_speakers(self): return list(self.speakers.values()) diff --git a/modules/webui/speaker/__init__.py b/modules/webui/speaker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/webui/speaker/speaker_creator.py b/modules/webui/speaker/speaker_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..344ab1f08317be3f7741d0aa3d3c705b64e34aef --- /dev/null +++ b/modules/webui/speaker/speaker_creator.py @@ -0,0 +1,171 @@ +import gradio as gr +import torch +from modules.speaker import Speaker +from modules.utils.SeedContext import SeedContext +from modules.hf import spaces +from modules.models import load_chat_tts +from modules.utils.rng import np_rng +from modules.webui.webui_utils import get_speakers, tts_generate + +import tempfile + +names_list = [ + "Alice", + "Bob", + "Carol", + "Carlos", + "Charlie", + "Chuck", + "Chad", + "Craig", + "Dan", + "Dave", + "David", + "Erin", + "Eve", + "Yves", + "Faythe", + "Frank", + "Grace", + "Heidi", + "Ivan", + "Judy", + "Mallory", + "Mallet", + "Darth", + "Michael", + "Mike", + "Niaj", + "Olivia", + "Oscar", + "Peggy", + "Pat", + "Rupert", + "Sybil", + "Trent", + "Ted", + "Trudy", + "Victor", + "Vanna", + "Walter", + "Wendy", +] + + +@torch.inference_mode() +@spaces.GPU +def create_spk_from_seed( + seed: int, + name: str, + gender: str, + desc: str, +): + chat_tts = load_chat_tts() + with SeedContext(seed): + emb = chat_tts.sample_random_speaker() + spk = Speaker(seed=-2, name=name, gender=gender, describe=desc) + spk.emb = emb + + with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file: + torch.save(spk, tmp_file) + tmp_file_path = tmp_file.name + + return tmp_file_path + + +@torch.inference_mode() +@spaces.GPU +def test_spk_voice(seed: int, text: str): + return tts_generate( + spk=seed, + text=text, + ) + + +def random_speaker(): + seed = np_rng() + name = names_list[seed % len(names_list)] + return seed, name + + +creator_ui_desc = """ +## Speaker Creator +使用本面板快捷抽卡生成 speaker.pt 文件。 + +1. **生成说话人**:输入种子、名字、性别和描述。点击 "Generate speaker.pt" 按钮,生成的说话人配置会保存为.pt文件。 +2. **测试说话人声音**:输入测试文本。点击 "Test Voice" 按钮,生成的音频会在 "Output Audio" 中播放。 +3. **随机生成说话人**:点击 "Random Speaker" 按钮,随机生成一个种子和名字,可以进一步编辑其他信息并测试。 +""" + + +def speaker_creator_ui(): + def on_generate(seed, name, gender, desc): + file_path = create_spk_from_seed(seed, name, gender, desc) + return file_path + + def create_test_voice_card(seed_input): + with gr.Group(): + gr.Markdown("🎤Test voice") + with gr.Row(): + test_voice_btn = gr.Button("Test Voice", variant="secondary") + + with gr.Column(scale=4): + test_text = gr.Textbox( + label="Test Text", + placeholder="Please input test text", + value="说话人测试 123456789 [uv_break] ok, test done [lbreak]", + ) + with gr.Row(): + current_seed = gr.Label(label="Current Seed", value=-1) + with gr.Column(scale=4): + output_audio = gr.Audio(label="Output Audio") + + test_voice_btn.click( + fn=test_spk_voice, + inputs=[seed_input, test_text], + outputs=[output_audio], + ) + test_voice_btn.click( + fn=lambda x: x, + inputs=[seed_input], + outputs=[current_seed], + ) + + gr.Markdown(creator_ui_desc) + + with gr.Row(): + with gr.Column(scale=2): + with gr.Group(): + gr.Markdown("ℹ️Speaker info") + seed_input = gr.Number(label="Seed", value=2) + name_input = gr.Textbox( + label="Name", placeholder="Enter speaker name", value="Bob" + ) + gender_input = gr.Textbox( + label="Gender", placeholder="Enter gender", value="*" + ) + desc_input = gr.Textbox( + label="Description", + placeholder="Enter description", + ) + random_button = gr.Button("Random Speaker") + with gr.Group(): + gr.Markdown("🔊Generate speaker.pt") + generate_button = gr.Button("Save .pt file") + output_file = gr.File(label="Save to File") + with gr.Column(scale=5): + create_test_voice_card(seed_input=seed_input) + create_test_voice_card(seed_input=seed_input) + create_test_voice_card(seed_input=seed_input) + create_test_voice_card(seed_input=seed_input) + + random_button.click( + random_speaker, + outputs=[seed_input, name_input], + ) + + generate_button.click( + fn=on_generate, + inputs=[seed_input, name_input, gender_input, desc_input], + outputs=[output_file], + ) diff --git a/modules/webui/speaker/speaker_merger.py b/modules/webui/speaker/speaker_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..66a0854790d258f3e0cb3476efc995aac862364f --- /dev/null +++ b/modules/webui/speaker/speaker_merger.py @@ -0,0 +1,255 @@ +import io +import gradio as gr +import torch + +from modules.hf import spaces +from modules.webui.webui_utils import get_speakers, tts_generate +from modules.speaker import speaker_mgr, Speaker + +import tempfile + + +def spk_to_tensor(spk): + spk = spk.split(" : ")[1].strip() if " : " in spk else spk + if spk == "None" or spk == "": + return None + return speaker_mgr.get_speaker(spk).emb + + +def get_speaker_show_name(spk): + if spk.gender == "*" or spk.gender == "": + return spk.name + return f"{spk.gender} : {spk.name}" + + +def merge_spk( + spk_a, + spk_a_w, + spk_b, + spk_b_w, + spk_c, + spk_c_w, + spk_d, + spk_d_w, +): + tensor_a = spk_to_tensor(spk_a) + tensor_b = spk_to_tensor(spk_b) + tensor_c = spk_to_tensor(spk_c) + tensor_d = spk_to_tensor(spk_d) + + assert ( + tensor_a is not None + or tensor_b is not None + or tensor_c is not None + or tensor_d is not None + ), "At least one speaker should be selected" + + merge_tensor = torch.zeros_like( + tensor_a + if tensor_a is not None + else ( + tensor_b + if tensor_b is not None + else tensor_c if tensor_c is not None else tensor_d + ) + ) + + total_weight = 0 + if tensor_a is not None: + merge_tensor += spk_a_w * tensor_a + total_weight += spk_a_w + if tensor_b is not None: + merge_tensor += spk_b_w * tensor_b + total_weight += spk_b_w + if tensor_c is not None: + merge_tensor += spk_c_w * tensor_c + total_weight += spk_c_w + if tensor_d is not None: + merge_tensor += spk_d_w * tensor_d + total_weight += spk_d_w + + if total_weight > 0: + merge_tensor /= total_weight + + merged_spk = Speaker.from_tensor(merge_tensor) + merged_spk.name = "" + + return merged_spk + + +@torch.inference_mode() +@spaces.GPU +def merge_and_test_spk_voice( + spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text +): + merged_spk = merge_spk( + spk_a, + spk_a_w, + spk_b, + spk_b_w, + spk_c, + spk_c_w, + spk_d, + spk_d_w, + ) + return tts_generate( + spk=merged_spk, + text=test_text, + ) + + +@torch.inference_mode() +@spaces.GPU +def merge_spk_to_file( + spk_a, + spk_a_w, + spk_b, + spk_b_w, + spk_c, + spk_c_w, + spk_d, + spk_d_w, + speaker_name, + speaker_gender, + speaker_desc, +): + merged_spk = merge_spk( + spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w + ) + merged_spk.name = speaker_name + merged_spk.gender = speaker_gender + merged_spk.desc = speaker_desc + + with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file: + torch.save(merged_spk, tmp_file) + tmp_file_path = tmp_file.name + + return tmp_file_path + + +merge_desc = """ +## Speaker Merger + +在本面板中,您可以选择多个说话人并指定他们的权重,合成新的语音并进行测试。以下是各个功能的详细说明: + +1. 选择说话人: 您可以从下拉菜单中选择最多四个说话人(A、B、C、D),每个说话人都有一个对应的权重滑块,范围从0到10。权重决定了每个说话人在合成语音中的影响程度。 +2. 合成语音: 在选择好说话人和设置好权重后,您可以在“Test Text”框中输入要测试的文本,然后点击“测试语音”按钮来生成并播放合成的语音。 +3. 保存说话人: 您还可以在右侧的“说话人信息”部分填写新的说话人的名称、性别和描述,并点击“Save Speaker”按钮来保存合成的说话人。保存后的说话人文件将显示在“Merged Speaker”栏中,供下载使用。 +""" + + +def get_spk_choices(): + speakers = get_speakers() + + speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers] + return speaker_names + + +# 显示 a b c d 四个选择框,选择一个或多个,然后可以试音,并导出 +def create_speaker_merger(): + speaker_names = get_spk_choices() + + gr.Markdown(merge_desc) + + def spk_picker(label_tail: str): + with gr.Row(): + spk_a = gr.Dropdown( + choices=speaker_names, value="None", label=f"Speaker {label_tail}" + ) + refresh_a_btn = gr.Button("🔄", variant="secondary") + + def refresh_a(): + speaker_mgr.refresh_speakers() + speaker_names = get_spk_choices() + return gr.update(choices=speaker_names) + + refresh_a_btn.click(refresh_a, outputs=[spk_a]) + spk_a_w = gr.Slider( + value=1, + minimum=0, + maximum=10, + step=0.1, + label=f"Weight {label_tail}", + ) + return spk_a, spk_a_w + + with gr.Row(): + with gr.Column(scale=5): + with gr.Row(): + with gr.Group(): + spk_a, spk_a_w = spk_picker("A") + + with gr.Group(): + spk_b, spk_b_w = spk_picker("B") + + with gr.Group(): + spk_c, spk_c_w = spk_picker("C") + + with gr.Group(): + spk_d, spk_d_w = spk_picker("D") + + with gr.Row(): + with gr.Column(scale=3): + with gr.Group(): + gr.Markdown("🎤Test voice") + with gr.Row(): + test_voice_btn = gr.Button( + "Test Voice", variant="secondary" + ) + + with gr.Column(scale=4): + test_text = gr.Textbox( + label="Test Text", + placeholder="Please input test text", + value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]", + ) + + output_audio = gr.Audio(label="Output Audio") + + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("🗃️Save to file") + + speaker_name = gr.Textbox(label="Name", value="forge_speaker_merged") + speaker_gender = gr.Textbox(label="Gender", value="*") + speaker_desc = gr.Textbox(label="Description", value="merged speaker") + + save_btn = gr.Button("Save Speaker", variant="primary") + + merged_spker = gr.File( + label="Merged Speaker", interactive=False, type="binary" + ) + + test_voice_btn.click( + merge_and_test_spk_voice, + inputs=[ + spk_a, + spk_a_w, + spk_b, + spk_b_w, + spk_c, + spk_c_w, + spk_d, + spk_d_w, + test_text, + ], + outputs=[output_audio], + ) + + save_btn.click( + merge_spk_to_file, + inputs=[ + spk_a, + spk_a_w, + spk_b, + spk_b_w, + spk_c, + spk_c_w, + spk_d, + spk_d_w, + speaker_name, + speaker_gender, + speaker_desc, + ], + outputs=[merged_spker], + ) diff --git a/modules/webui/speaker_tab.py b/modules/webui/speaker_tab.py index 31abf96c4b2acc213a674603ee0e44400add3e4b..4021bc6646a9877dcd29284f49e4a95cab3e6531 100644 --- a/modules/webui/speaker_tab.py +++ b/modules/webui/speaker_tab.py @@ -1,259 +1,13 @@ -import io import gradio as gr -import torch -from modules.hf import spaces -from modules.webui.webui_utils import get_speakers, tts_generate -from modules.speaker import speaker_mgr, Speaker +from modules.webui.speaker.speaker_merger import create_speaker_merger +from modules.webui.speaker.speaker_creator import speaker_creator_ui -import tempfile - -def spk_to_tensor(spk): - spk = spk.split(" : ")[1].strip() if " : " in spk else spk - if spk == "None" or spk == "": - return None - return speaker_mgr.get_speaker(spk).emb - - -def get_speaker_show_name(spk): - if spk.gender == "*" or spk.gender == "": - return spk.name - return f"{spk.gender} : {spk.name}" - - -def merge_spk( - spk_a, - spk_a_w, - spk_b, - spk_b_w, - spk_c, - spk_c_w, - spk_d, - spk_d_w, -): - tensor_a = spk_to_tensor(spk_a) - tensor_b = spk_to_tensor(spk_b) - tensor_c = spk_to_tensor(spk_c) - tensor_d = spk_to_tensor(spk_d) - - assert ( - tensor_a is not None - or tensor_b is not None - or tensor_c is not None - or tensor_d is not None - ), "At least one speaker should be selected" - - merge_tensor = torch.zeros_like( - tensor_a - if tensor_a is not None - else ( - tensor_b - if tensor_b is not None - else tensor_c if tensor_c is not None else tensor_d - ) - ) - - total_weight = 0 - if tensor_a is not None: - merge_tensor += spk_a_w * tensor_a - total_weight += spk_a_w - if tensor_b is not None: - merge_tensor += spk_b_w * tensor_b - total_weight += spk_b_w - if tensor_c is not None: - merge_tensor += spk_c_w * tensor_c - total_weight += spk_c_w - if tensor_d is not None: - merge_tensor += spk_d_w * tensor_d - total_weight += spk_d_w - - if total_weight > 0: - merge_tensor /= total_weight - - merged_spk = Speaker.from_tensor(merge_tensor) - merged_spk.name = "" - - return merged_spk - - -@torch.inference_mode() -@spaces.GPU -def merge_and_test_spk_voice( - spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text -): - merged_spk = merge_spk( - spk_a, - spk_a_w, - spk_b, - spk_b_w, - spk_c, - spk_c_w, - spk_d, - spk_d_w, - ) - return tts_generate( - spk=merged_spk, - text=test_text, - ) - - -@torch.inference_mode() -@spaces.GPU -def merge_spk_to_file( - spk_a, - spk_a_w, - spk_b, - spk_b_w, - spk_c, - spk_c_w, - spk_d, - spk_d_w, - speaker_name, - speaker_gender, - speaker_desc, -): - merged_spk = merge_spk( - spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w - ) - merged_spk.name = speaker_name - merged_spk.gender = speaker_gender - merged_spk.desc = speaker_desc - - with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file: - torch.save(merged_spk, tmp_file) - tmp_file_path = tmp_file.name - - return tmp_file_path - - -merge_desc = """ -## Speaker Merger - -在本面板中,您可以选择多个说话人并指定他们的权重,合成新的语音并进行测试。以下是各个功能的详细说明: - -### 1. 选择说话人 -您可以从下拉菜单中选择最多四个说话人(A、B、C、D),每个说话人都有一个对应的权重滑块,范围从0到10。权重决定了每个说话人在合成语音中的影响程度。 - -### 2. 合成语音 -在选择好说话人和设置好权重后,您可以在“测试文本”框中输入要测试的文本,然后点击“测试语音”按钮来生成并播放合成的语音。 - -### 3. 保存说话人 -您还可以在右侧的“说话人信息”部分填写新的说话人的名称、性别和描述,并点击“保存说话人”按钮来保存合成的说话人。保存后的说话人文件将显示在“合成说话人”栏中,供下载使用。 -""" - - -# 显示 a b c d 四个选择框,选择一个或多个,然后可以试音,并导出 def create_speaker_panel(): - speakers = get_speakers() - - speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers] with gr.Tabs(): + with gr.TabItem("Creator"): + speaker_creator_ui() with gr.TabItem("Merger"): - gr.Markdown(merge_desc) - - with gr.Row(): - with gr.Column(scale=5): - with gr.Row(): - with gr.Group(): - spk_a = gr.Dropdown( - choices=speaker_names, value="None", label="Speaker A" - ) - spk_a_w = gr.Slider( - value=1, minimum=0, maximum=10, step=1, label="Weight A" - ) - - with gr.Group(): - spk_b = gr.Dropdown( - choices=speaker_names, value="None", label="Speaker B" - ) - spk_b_w = gr.Slider( - value=1, minimum=0, maximum=10, step=1, label="Weight B" - ) - - with gr.Group(): - spk_c = gr.Dropdown( - choices=speaker_names, value="None", label="Speaker C" - ) - spk_c_w = gr.Slider( - value=1, minimum=0, maximum=10, step=1, label="Weight C" - ) - - with gr.Group(): - spk_d = gr.Dropdown( - choices=speaker_names, value="None", label="Speaker D" - ) - spk_d_w = gr.Slider( - value=1, minimum=0, maximum=10, step=1, label="Weight D" - ) - - with gr.Row(): - with gr.Column(scale=3): - with gr.Group(): - gr.Markdown("🎤Test voice") - with gr.Row(): - test_voice_btn = gr.Button( - "Test Voice", variant="secondary" - ) - - with gr.Column(scale=4): - test_text = gr.Textbox( - label="Test Text", - placeholder="Please input test text", - value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]", - ) - - output_audio = gr.Audio(label="Output Audio") - - with gr.Column(scale=1): - with gr.Group(): - gr.Markdown("🗃️Save to file") - - speaker_name = gr.Textbox( - label="Name", value="forge_speaker_merged" - ) - speaker_gender = gr.Textbox(label="Gender", value="*") - speaker_desc = gr.Textbox( - label="Description", value="merged speaker" - ) - - save_btn = gr.Button("Save Speaker", variant="primary") - - merged_spker = gr.File( - label="Merged Speaker", interactive=False, type="binary" - ) - - test_voice_btn.click( - merge_and_test_spk_voice, - inputs=[ - spk_a, - spk_a_w, - spk_b, - spk_b_w, - spk_c, - spk_c_w, - spk_d, - spk_d_w, - test_text, - ], - outputs=[output_audio], - ) - - save_btn.click( - merge_spk_to_file, - inputs=[ - spk_a, - spk_a_w, - spk_b, - spk_b_w, - spk_c, - spk_c_w, - spk_d, - spk_d_w, - speaker_name, - speaker_gender, - speaker_desc, - ], - outputs=[merged_spker], - ) + create_speaker_merger() diff --git a/modules/webui/tts_tab.py b/modules/webui/tts_tab.py index 0c807d5e2fe6e514b44917c891db16c26557eaca..d51cb2eea4c646590c5b3f63b7ec266ade221a44 100644 --- a/modules/webui/tts_tab.py +++ b/modules/webui/tts_tab.py @@ -13,10 +13,7 @@ from modules import config default_text_content = """ chat T T S 是一款强大的对话式文本转语音模型。它有中英混读和多说话人的能力。 -chat T T S 不仅能够生成自然流畅的语音,还能控制[laugh]笑声啊[laugh], -停顿啊[uv_break]语气词啊等副语言现象[uv_break]。这个韵律超越了许多开源模型[uv_break]。 -请注意,chat T T S 的使用应遵守法律和伦理准则,避免滥用的安全风险。[uv_break] -""" +""".strip() def create_tts_interface():