File size: 2,187 Bytes
2045faa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from syslog import LOG_DAEMON
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers

class LogMelgramLayer(Model):
    """Compute mel (or log-mel) power spectrograms from waveforms.

    All configuration is read from ``kwargs``:

    * ``log_mel`` (required): if truthy, ``call`` returns
      ``log10(mel_power + epsilon)``; otherwise the linear mel power.
    * ``frame_length`` (required, int): STFT window length in samples.
    * ``hop_length`` (required, int): STFT hop size in samples.
    * ``num_mel`` (required, int): number of mel bins.
    * ``sample_rate`` (required): sampling rate of the input in Hz.
    * ``lower_edge_hertz`` (optional, default 80): lowest mel filter edge.
    * ``upper_edge_hertz`` (optional, default ``sample_rate / 2``): highest
      mel filter edge.

    Any other keyword arguments are ignored (they are not forwarded to the
    Keras ``Model`` constructor).
    """

    def __init__(self, name="mel_specgram", **kwargs):
        log_mel = kwargs['log_mel']
        # Substitute the log-specific name only when the caller kept the
        # default; an explicitly supplied name is honoured in both modes.
        if log_mel and name == "mel_specgram":
            name = "log_mel_specgram"
        # Keras requires super().__init__() before any attribute assignment.
        super().__init__(name=name)

        self.log_mel = log_mel
        self.frame_length = kwargs['frame_length']
        self.hop_length = kwargs['hop_length']
        # Smallest power of two >= frame_length. This equals the default FFT
        # size of tf.signal.stft; it is passed explicitly in call() so the
        # mel matrix below and the STFT can never disagree.
        self.fft_length = 1 << (int(self.frame_length) - 1).bit_length()

        sample_rate = kwargs['sample_rate']
        # A real FFT of length fft_length yields fft_length // 2 + 1 bins.
        num_freqs = (self.fft_length // 2) + 1
        # Shape (num_freqs, num_mel). Stored as a plain constant tensor: the
        # former `self.non_trainable_weights.append(...)` was a no-op, since
        # that Keras property returns a freshly built list on every access.
        self.lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=kwargs['num_mel'],
            num_spectrogram_bins=num_freqs,
            sample_rate=sample_rate,
            lower_edge_hertz=kwargs.get('lower_edge_hertz', 80),
            upper_edge_hertz=kwargs.get('upper_edge_hertz', sample_rate / 2),
        )

    @staticmethod
    def _tf_log10(x):
        """Return the element-wise base-10 logarithm of ``x``."""
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
        return numerator / denominator

    def call(self, input):  # `input` name kept for keyword-call compatibility
        """Compute mel (or log-mel) power spectrograms.

        Args:
            input (tensor): Real-valued waveforms with time on the last axis,
                e.g. a batch of mono waveforms of shape (None, N).

        Returns:
            tensor: Shape ``input.shape[:-1] + (num_frames, num_mel)``; for
            input (None, N) this is (None, num_frames, num_mel). No channel
            axis is added. When ``log_mel`` is set the values are
            ``log10(mel_power + epsilon)``, otherwise the mel power.

            The returned tensor carries a ``mask`` attribute: a boolean tensor
            of shape ``input.shape[:-1] + (num_frames,)`` that is False for
            frames whose mel power is exactly zero in every bin.
        """
        # tf.signal.stft frames along the last axis of its input.
        stfts = tf.signal.stft(
            input,
            frame_length=self.frame_length,
            frame_step=self.hop_length,
            fft_length=self.fft_length,
        )
        power = tf.square(tf.abs(stfts))

        # Match the matrix dtype to the spectrogram (no-op for float32) and
        # contract the last (frequency) axis with the matrix's first axis;
        # axes=1 works for any input rank.
        lin_to_mel = tf.cast(self.lin_to_mel_matrix, power.dtype)
        melgrams = tf.tensordot(power, lin_to_mel, axes=1)

        # Same mask tf.keras.layers.Masking(mask_value=0.0) computes, without
        # instantiating a new Keras layer on every call.
        mask = tf.reduce_any(tf.not_equal(melgrams, 0.0), axis=-1)
        # NOTE(review): Keras propagates masks through `_keras_mask`, not
        # `.mask`, so downstream Keras layers do not consume this mask; it is
        # only visible to code that reads `.mask` explicitly. Confirm intent.
        melgrams.mask = mask

        if not self.log_mel:
            return melgrams

        log_melgrams = self._tf_log10(melgrams + tf.keras.backend.epsilon())
        # Reuse the linear mask: the masked frames are exactly those with zero
        # mel power. Comparing the log output to -7.0 (~log10(epsilon)) relied
        # on float32 rounding landing exactly on -7.0.
        log_melgrams.mask = mask
        return log_melgrams