Spaces:
Runtime error
Runtime error
File size: 2,187 Bytes
2045faa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from syslog import LOG_DAEMON
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers
class LogMelgramLayer(Model):
    """Convert a batch of mono waveforms into mel or log10-mel spectrograms.

    Pipeline: STFT -> squared magnitude (power) -> projection onto a mel
    filterbank spanning 80 Hz .. sample_rate / 2 -> optional log10.

    Constructor keyword arguments (all required; a missing key raises
    ``KeyError``):
        log_mel: if truthy, ``call`` returns log10 of the mel power
            spectrogram; otherwise the linear mel power spectrogram.
        frame_length (int): STFT window length in samples.
        hop_length (int): STFT hop size in samples.
        num_mel (int): number of mel bins.
        sample_rate: input sampling rate in Hz.

    When ``log_mel`` is truthy the model is always named
    ``"log_mel_specgram"`` and ``name`` is ignored. The keyword arguments
    are consumed here and are not forwarded to ``Model.__init__``.
    """

    def __init__(self, name="mel_specgram", **kwargs):
        # NOTE(review): the log-mel variant ignores ``name``. Kept unchanged
        # because layer names can matter when weights are matched by name.
        layer_name = "log_mel_specgram" if kwargs['log_mel'] else name
        super().__init__(name=layer_name)
        self.log_mel = kwargs['log_mel']
        self.hop_length = kwargs['hop_length']
        self.frame_length = kwargs['frame_length']
        # Smallest power of two >= frame_length. This equals tf.signal.stft's
        # default fft_length; call() passes it explicitly so the number of
        # STFT bins is guaranteed to match the mel matrix's row count.
        self.num_fft = 1 << (self.frame_length - 1).bit_length()
        num_freqs = (self.num_fft // 2) + 1
        # Shape (num_freqs, num_mel): weights mapping linear-frequency bins
        # to mel bins.
        self.lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=kwargs['num_mel'],
            num_spectrogram_bins=num_freqs,
            sample_rate=kwargs['sample_rate'],
            lower_edge_hertz=80,
            upper_edge_hertz=kwargs['sample_rate'] / 2,
        )
        # A former ``self.non_trainable_weights.append(...)`` was dropped:
        # Keras builds a new list each time that property is read, so the
        # append never registered anything. The matrix is a constant tensor
        # computed here on construction, not a Keras-tracked variable.

    @staticmethod
    def _log10(x):
        """Return the element-wise base-10 logarithm of ``x`` (via natural logs)."""
        return tf.math.log(x) / tf.math.log(tf.constant(10, dtype=x.dtype))

    @staticmethod
    def _frame_mask(x, mask_value):
        """Return a boolean mask over every axis of ``x`` except the last.

        An entry is False when all values along the last axis equal
        ``mask_value``, and True otherwise. This is the same mask that
        ``tf.keras.layers.Masking(mask_value)`` computes.
        """
        return tf.reduce_any(tf.not_equal(x, mask_value), axis=-1)

    def call(self, input):
        """Compute mel (or log10-mel) power spectrograms.

        Args:
            input (tensor): Batch of mono waveforms, shape (None, N). It must
                be rank 2, because the mel projection contracts axis 2 of
                the STFT output.

        Returns:
            tensor: Shape (None, num_frames, num_mel). No channel axis is
            added. The values are log10 mel power if ``log_mel`` is set,
            otherwise linear mel power. The returned tensor also carries a
            ``.mask`` attribute of shape (None, num_frames). An entry is
            False where every mel bin of that frame equals the "silent"
            value (0.0 for linear output, -7.0 for log output).
        """
        # tf.signal.stft frames the last axis: (batch, N) -> (batch, frames, bins).
        stfts = tf.signal.stft(
            input,
            frame_length=self.frame_length,
            frame_step=self.hop_length,
            fft_length=self.num_fft,
        )
        power = tf.square(tf.abs(stfts))
        # Contract the frequency axis (2) with the matrix rows (0):
        # (batch, frames, bins) x (bins, num_mel) -> (batch, frames, num_mel).
        melgrams = tf.tensordot(power, self.lin_to_mel_matrix, axes=[2, 0])
        # NOTE(review): Keras reads ``_keras_mask``, not ``.mask``, so this
        # attribute does not propagate masking to downstream layers. It is
        # kept only for callers that may read it directly. Computing it
        # inline avoids constructing a new Masking layer on every call.
        melgrams.mask = self._frame_mask(melgrams, 0.0)
        if not self.log_mel:
            return melgrams
        # Adding epsilon avoids log10(0). With the default Keras epsilon
        # (1e-7), an all-zero frame maps to about -7.0, which is presumably
        # why -7.0 is the mask value. NOTE(review): this is an exact float
        # comparison and may miss such frames after rounding -- confirm.
        log_melgrams = self._log10(melgrams + tf.keras.backend.epsilon())
        log_melgrams.mask = self._frame_mask(log_melgrams, -7.0)
        return log_melgrams