# Spaces: Runtime error  (page-status text captured with the file; not part of the program)
from syslog import LOG_DAEMON | |
import tensorflow as tf | |
from tensorflow.keras.models import Model | |
from tensorflow.keras import layers | |
class LogMelgramLayer(Model):
    """Convert batches of mono waveforms into (log-)mel power spectrograms.

    The model computes an STFT of the input, squares its magnitude, projects
    the result onto a mel filterbank and, when ``log_mel`` is true, takes the
    base-10 logarithm of the mel energies.

    Required keyword arguments:
        log_mel: When truthy, ``call`` returns log10 mel energies and the model
            is always named ``"log_mel_specgram"`` (``name`` is ignored).
            Otherwise linear mel energies are returned and ``name`` is used.
        frame_length (int): STFT window length in samples.
        hop_length (int): STFT frame step in samples.
        num_mel (int): Number of mel bins.
        sample_rate: Sample rate in Hz. The mel filterbank spans 80 Hz to
            ``sample_rate / 2``.

    Any other keyword argument is ignored.
    """

    def __init__(self, name="mel_specgram", **kwargs):
        """Build the mel filterbank and store the STFT parameters.

        Raises:
            KeyError: If any of the required keyword arguments is missing.
        """
        required = ('log_mel', 'frame_length', 'hop_length', 'num_mel', 'sample_rate')
        missing = [key for key in required if key not in kwargs]
        if missing:
            # Still a KeyError (as before), but naming every missing option.
            raise KeyError(
                "LogMelgramLayer missing required keyword argument(s): {}".format(
                    ", ".join(missing)
                )
            )

        log_mel = kwargs['log_mel']
        super().__init__(name="log_mel_specgram" if log_mel else name)

        self.log_mel = log_mel
        self.frame_length = kwargs['frame_length']
        self.hop_length = kwargs['hop_length']

        # Smallest power of two >= frame_length. It is stored so that `call`
        # can hand the very same value to tf.signal.stft, keeping the number
        # of spectrogram bins and the mel matrix rows in agreement.
        self.num_fft = 1 << (self.frame_length - 1).bit_length()
        num_freqs = (self.num_fft // 2) + 1

        sample_rate = kwargs['sample_rate']
        # Constant projection matrix of shape (num_freqs, num_mel). It is a
        # plain tensor, not a variable, so it is deliberately not registered
        # as a weight (appending to the list returned by the
        # `non_trainable_weights` property would have no effect anyway).
        self.lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=kwargs['num_mel'],
            num_spectrogram_bins=num_freqs,
            sample_rate=sample_rate,
            lower_edge_hertz=80,
            upper_edge_hertz=sample_rate / 2,
        )

    @staticmethod
    def _log10(x):
        """Return the element-wise base-10 logarithm of ``x``."""
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
        return numerator / denominator

    @staticmethod
    def _frame_mask(values, mask_value):
        """Return a boolean mask that is False where a whole frame equals ``mask_value``.

        This is the same reduction the Keras ``Masking`` layer applies over the
        last axis, computed directly instead of instantiating a layer on every
        call and reading its private ``_keras_mask`` attribute.
        """
        return tf.reduce_any(tf.not_equal(values, mask_value), axis=-1)

    def call(self, input):
        """Compute the mel (or log-mel) power spectrogram of a waveform batch.

        Args:
            input (tensor): Batch of mono waveforms, shape: (None, N).

        Returns:
            tensor: Mel power spectrograms, or their log10 when ``log_mel`` is
            set, shape: (None, num_frame, mel_bins). No channel axis is added.
            The returned tensor carries a ``mask`` attribute of shape
            (None, num_frame) that is False for frames whose mel bins all
            equal the masking value (0.0 linear, -7.0 log).
        """
        # tf.signal.stft frames along the last axis.
        stfts = tf.signal.stft(
            input,
            frame_length=self.frame_length,
            frame_step=self.hop_length,
            fft_length=self.num_fft,
        )
        power_spectrogram = tf.square(tf.abs(stfts))

        # Contract the frequency axis (2) with the rows (0) of the mel matrix.
        melgrams = tf.tensordot(power_spectrogram, self.lin_to_mel_matrix, axes=[2, 0])
        melgrams.mask = self._frame_mask(melgrams, 0.0)
        if not self.log_mel:
            return melgrams

        log_melgrams = self._log10(melgrams + tf.keras.backend.epsilon())
        # NOTE(review): -7.0 presumably stands for log10 of the default Keras
        # epsilon (1e-7), i.e. an all-zero frame; confirm this still holds if
        # the backend epsilon is changed or if float rounding matters here.
        log_melgrams.mask = self._frame_mask(log_melgrams, -7.0)
        return log_melgrams