tts / vietTTS /nat /dsp.py
tobiccino's picture
upload
12da6cc
raw
history blame
No virus
4.11 kB
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
import librosa
from einops import rearrange
from jax.numpy import ndarray
def rolling_window(a: ndarray, window: int, hop_length: int):
"""return a stack of overlap subsequence of an array.
``return jnp.stack( [a[0:10], a[5:15], a[10:20],...], axis=0)``
Source: https://github.com/google/jax/issues/3171
Args:
a (ndarray): input array of shape `[L, ...]`
window (int): length of each subarray (window).
hop_length (int): distance between neighbouring windows.
"""
idx = (
jnp.arange(window)[:, None]
+ jnp.arange((len(a) - window) // hop_length + 1)[None, :] * hop_length
)
return a[idx]
@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
def stft(
y: ndarray,
n_fft: int = 2048,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: str = "hann",
center: bool = True,
pad_mode: str = "reflect",
):
"""A jax reimplementation of ``librosa.stft`` function."""
if win_length is None:
win_length = n_fft
if hop_length is None:
hop_length = win_length // 4
if window == "hann":
fft_window = jnp.hanning(win_length + 1)[:-1]
else:
raise RuntimeError(f"{window} window function is not supported!")
pad_len = (n_fft - win_length) // 2
fft_window = jnp.pad(fft_window, (pad_len, pad_len), mode="constant")
fft_window = fft_window[:, None]
if center:
y = jnp.pad(y, int(n_fft // 2), mode=pad_mode)
# jax does not support ``np.lib.stride_tricks.as_strided`` function
# see https://github.com/google/jax/issues/3171 for comments.
y_frames = rolling_window(y, n_fft, hop_length) * fft_window
stft_matrix = jnp.fft.fft(y_frames, axis=0)
d = int(1 + n_fft // 2)
return stft_matrix[:d]
@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
def batched_stft(
y: ndarray,
n_fft: int,
hop_length: int,
win_length: int,
window: str,
center: bool = True,
pad_mode: str = "reflect",
):
"""Batched version of ``stft`` function.
TN => FTN
"""
assert len(y.shape) >= 2
if window == "hann":
fft_window = jnp.hanning(win_length + 1)[:-1]
else:
raise RuntimeError(f"{window} window function is not supported!")
pad_len = (n_fft - win_length) // 2
if pad_len > 0:
fft_window = jnp.pad(fft_window, (pad_len, pad_len), mode="constant")
win_length = n_fft
else:
fft_window = fft_window
if center:
pad_width = ((n_fft // 2, n_fft // 2),) + ((0, 0),) * (len(y.shape) - 1)
y = jnp.pad(y, pad_width, mode=pad_mode)
# jax does not support ``np.lib.stride_tricks.as_strided`` function
# see https://github.com/google/jax/issues/3171 for comments.
y_frames = rolling_window(y, n_fft, hop_length)
fft_window = jnp.reshape(fft_window, (-1,) + (1,) * (len(y.shape)))
y_frames = y_frames * fft_window
stft_matrix = jnp.fft.fft(y_frames, axis=0)
d = int(1 + n_fft // 2)
return stft_matrix[:d]
class MelFilter:
"""Convert waveform to mel spectrogram."""
def __init__(self, sample_rate: int, n_fft: int, n_mels: int, fmin=0.0, fmax=8000):
self.melfb = jax.device_put(
librosa.filters.mel(
sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
)
)
self.n_fft = n_fft
def __call__(self, y: ndarray) -> ndarray:
hop_length = self.n_fft // 4
window_length = self.n_fft
assert len(y.shape) == 2
y = rearrange(y, "n s -> s n")
p = (self.n_fft - hop_length) // 2
y = jnp.pad(y, ((p, p), (0, 0)), mode="reflect")
spec = batched_stft(
y, self.n_fft, hop_length, window_length, "hann", False, "reflect"
)
mag = jnp.sqrt(jnp.square(spec.real) + jnp.square(spec.imag) + 1e-9)
mel = jnp.einsum("ms,sfn->nfm", self.melfb, mag)
cond = jnp.log(jnp.clip(mel, a_min=1e-5, a_max=None))
return cond