File size: 4,113 Bytes
12da6cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
import librosa
from einops import rearrange
from jax.numpy import ndarray
def rolling_window(a: ndarray, window: int, hop_length: int):
    """Gather overlapping, fixed-length frames from the leading axis of ``a``.

    Frame ``i`` holds ``a[i * hop_length : i * hop_length + window]``. The
    frames are laid out along axis 1 of the result, i.e. for ``window=10`` and
    ``hop_length=5`` the output equals
    ``jnp.stack([a[0:10], a[5:15], a[10:20], ...], axis=1)``.

    Source: https://github.com/google/jax/issues/3171

    Args:
        a (ndarray): input array of shape ``[L, ...]``.
        window (int): length of each frame.
        hop_length (int): distance between the starts of neighbouring frames.

    Returns:
        Array of shape ``[window, n_frames, ...]`` with
        ``n_frames = (L - window) // hop_length + 1``.
    """
    num_frames = (len(a) - window) // hop_length + 1
    # Start position of every frame along the leading axis.
    frame_starts = jnp.arange(num_frames) * hop_length
    # Offset of every sample inside a frame.
    sample_offsets = jnp.arange(window)
    # Broadcasting a column of offsets against a row of starts yields a
    # ``[window, n_frames]`` table of gather indices.
    gather_idx = sample_offsets[:, None] + frame_starts[None, :]
    return a[gather_idx]
@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
def stft(
    y: ndarray,
    n_fft: int = 2048,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: str = "hann",
    center: bool = True,
    pad_mode: str = "reflect",
):
    """A jax reimplementation of ``librosa.stft`` function.

    Args:
        y: input signal; the framing below treats its leading axis as time
            (a 1-D signal is the intended use -- the window is broadcast as a
            column over a ``[n_fft, n_frames]`` frame matrix).
        n_fft: length of each (zero-padded) frame fed to the FFT.
        hop_length: distance between neighbouring frames. Defaults to
            ``win_length // 4``.
        win_length: length of the window function. Defaults to ``n_fft``.
        window: name of the window function; only ``"hann"`` is supported.
        center: if true, pad ``y`` by ``n_fft // 2`` on both sides before
            framing.
        pad_mode: ``jnp.pad`` mode used when ``center`` is true.

    Returns:
        Complex array holding the first ``1 + n_fft // 2`` FFT bins of every
        frame, with frequency on axis 0 and frames on axis 1.

    Raises:
        RuntimeError: if ``window`` is not ``"hann"``.
    """
    if win_length is None:
        win_length = n_fft
    if hop_length is None:
        hop_length = win_length // 4

    if window != "hann":
        raise RuntimeError(f"{window} window function is not supported!")
    # Periodic Hann window: a symmetric window of length ``win_length + 1``
    # with its last sample dropped.
    fft_window = jnp.hanning(win_length + 1)[:-1]

    # Centre the window inside an ``n_fft``-long frame. Any odd leftover
    # sample goes to the right-hand side so the padded window is always
    # exactly ``n_fft`` long; padding ``total // 2`` on both sides would come
    # out one sample short whenever ``n_fft - win_length`` is odd.
    total_pad = n_fft - win_length
    left_pad = total_pad // 2
    fft_window = jnp.pad(
        fft_window, (left_pad, total_pad - left_pad), mode="constant"
    )
    fft_window = fft_window[:, None]

    if center:
        y = jnp.pad(y, int(n_fft // 2), mode=pad_mode)

    # jax does not support ``np.lib.stride_tricks.as_strided`` function
    # see https://github.com/google/jax/issues/3171 for comments.
    y_frames = rolling_window(y, n_fft, hop_length) * fft_window
    stft_matrix = jnp.fft.fft(y_frames, axis=0)
    # Keep only the non-redundant half of the spectrum.
    d = int(1 + n_fft // 2)
    return stft_matrix[:d]
@partial(jax.jit, static_argnums=[1, 2, 3, 4, 5, 6])
def batched_stft(
    y: ndarray,
    n_fft: int,
    hop_length: int,
    win_length: int,
    window: str,
    center: bool = True,
    pad_mode: str = "reflect",
):
    """Batched version of ``stft`` function.

    TN => FTN

    Args:
        y: input of at least two dimensions; axis 0 is framed as time and the
            remaining axes are carried through unchanged.
        n_fft: length of each (zero-padded) frame fed to the FFT.
        hop_length: distance between neighbouring frames.
        win_length: length of the window function.
        window: name of the window function; only ``"hann"`` is supported.
        center: if true, pad axis 0 of ``y`` by ``n_fft // 2`` on both sides
            before framing.
        pad_mode: ``jnp.pad`` mode used when ``center`` is true.

    Returns:
        Complex array holding the first ``1 + n_fft // 2`` FFT bins of every
        frame, laid out as ``[frequency, frame, ...trailing axes of y]``.

    Raises:
        RuntimeError: if ``window`` is not ``"hann"``.
    """
    assert len(y.shape) >= 2
    if window != "hann":
        raise RuntimeError(f"{window} window function is not supported!")
    # Periodic Hann window: a symmetric window of length ``win_length + 1``
    # with its last sample dropped.
    fft_window = jnp.hanning(win_length + 1)[:-1]

    # Centre the window inside an ``n_fft``-long frame. Any odd leftover
    # sample goes to the right-hand side so the padded window is always
    # exactly ``n_fft`` long; padding ``total // 2`` on both sides would come
    # out one sample short whenever ``n_fft - win_length`` is odd.
    total_pad = n_fft - win_length
    if total_pad > 0:
        left_pad = total_pad // 2
        fft_window = jnp.pad(
            fft_window, (left_pad, total_pad - left_pad), mode="constant"
        )

    if center:
        # Pad the time axis only; leave every trailing axis untouched.
        pad_width = ((n_fft // 2, n_fft // 2),) + ((0, 0),) * (len(y.shape) - 1)
        y = jnp.pad(y, pad_width, mode=pad_mode)

    # jax does not support ``np.lib.stride_tricks.as_strided`` function
    # see https://github.com/google/jax/issues/3171 for comments.
    y_frames = rolling_window(y, n_fft, hop_length)
    # Shape the window as ``[n_fft, 1, 1, ...]`` so it broadcasts over the
    # frame axis and every trailing axis of ``y_frames``.
    fft_window = jnp.reshape(fft_window, (-1,) + (1,) * (len(y.shape)))
    y_frames = y_frames * fft_window
    stft_matrix = jnp.fft.fft(y_frames, axis=0)
    # Keep only the non-redundant half of the spectrum.
    d = int(1 + n_fft // 2)
    return stft_matrix[:d]
class MelFilter:
    """Convert waveform to mel spectrogram."""

    def __init__(self, sample_rate: int, n_fft: int, n_mels: int, fmin=0.0, fmax=8000):
        """Build the mel filterbank and keep the FFT size.

        Args:
            sample_rate: sampling rate passed to ``librosa.filters.mel`` as ``sr``.
            n_fft: FFT size; also used as the STFT window length in ``__call__``.
            n_mels: number of mel bands.
            fmin: lowest filterbank frequency passed to ``librosa.filters.mel``.
            fmax: highest filterbank frequency passed to ``librosa.filters.mel``.
        """
        # The filterbank is built once with librosa and handed to
        # ``jax.device_put`` so that ``__call__`` works on a jax array.
        self.melfb = jax.device_put(
            librosa.filters.mel(
                sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
            )
        )
        self.n_fft = n_fft

    def __call__(self, y: ndarray) -> ndarray:
        """Compute the log-mel spectrogram of a batch of waveforms.

        Args:
            y: 2-D array laid out as ``[n, s]`` (the einops pattern below moves
                the second axis to the front so it becomes the framed axis).

        Returns:
            Array laid out as ``[n, frames, mels]`` holding the natural log of
            the mel magnitudes, floored at ``1e-5`` before the log.
        """
        hop_length = self.n_fft // 4
        window_length = self.n_fft
        assert len(y.shape) == 2
        # ``batched_stft`` frames along axis 0, so put the sample axis first.
        y = rearrange(y, "n s -> s n")
        # Manual reflect padding of ``(n_fft - hop_length) // 2`` per side; the
        # STFT below is therefore called with ``center=False``.
        p = (self.n_fft - hop_length) // 2
        y = jnp.pad(y, ((p, p), (0, 0)), mode="reflect")
        spec = batched_stft(
            y, self.n_fft, hop_length, window_length, "hann", False, "reflect"
        )
        # Magnitude with a small epsilon inside the square root.
        mag = jnp.sqrt(jnp.square(spec.real) + jnp.square(spec.imag) + 1e-9)
        # [mels, freq] x [freq, frames, n] -> [n, frames, mels]
        mel = jnp.einsum("ms,sfn->nfm", self.melfb, mag)
        # Lower-bound the mel energies before the log. ``jnp.maximum`` is used
        # instead of ``jnp.clip(..., a_min=..., a_max=None)`` because the
        # ``a_min``/``a_max`` keywords are deprecated in recent JAX releases;
        # with no upper bound the two are equivalent.
        cond = jnp.log(jnp.maximum(mel, 1e-5))
        return cond
|