# CL-KWS_202408_v1 / model / speech_embedding.py
import os
import tensorflow as tf
from tensorflow.keras.models import Model
class GoogleSpeechEmbedder(Model):
    """Frozen wrapper around the bundled Google speech-embedding SavedModel (96-dim frames)."""
    def __init__(self, name="google_embedding", **kwargs):
        super(GoogleSpeechEmbedder, self).__init__(name=name)
        # Load the frozen embedding SavedModel that ships next to this file.
        self._embeddingModel = tf.saved_model.load(
            os.path.join(os.path.abspath(os.path.dirname(__file__)), 'google_speech_embedding'),
            tags=[]).signatures["default"]
        # One embedding frame covers 12400 samples (775 ms at a 16 kHz sample rate)
        # and advances by 1280 samples (80 ms).
        self.window = 12400
        self.shift = 1280
        # Left padding so the first frame lines up with the first hop of real input.
        self.pre_padding = self.window - self.shift
        # Register the loaded variables as non-trainable: the embedder stays frozen.
        self.non_trainable_weights.extend(self._embeddingModel.variables)
    def __call__(self, speech):
        # speech: [batch, samples] zero-padded waveforms.
        # __call__ is overridden directly, bypassing Keras' standard call/mask machinery.
        batch = speech.shape[0]
        # Per-sample padding mask taken from the zero-padded input before any reshaping.
        mask = tf.keras.layers.Masking(mask_value=0.0)(tf.expand_dims(speech, -1))._keras_mask
        # Left-pad so the first embedding frame only covers the first hop of real audio.
        speech = tf.concat([tf.zeros([speech.shape[0], self.pre_padding]), speech], -1)
        assert speech.shape[-1] > self.window, 'Padded speech must be longer than one window (12400 samples)'
        # Number of embedding frames the padded signal should yield.
        est_end = int(1 + (speech.shape[-1] - self.window) // self.shift)
        # Trim the tail so the length is a multiple of the hop size.
        trim = int(speech.shape[-1] % self.shift)
        if trim > 0:
            speech = speech[:, :-trim]
        # Run the whole batch as one long stream so a single SavedModel call embeds every utterance.
        speech = tf.concat([tf.reshape(speech, [1, -1]), tf.zeros([1, self.pre_padding])], -1)
        emb = tf.reshape(self._embeddingModel(speech)['default'], [batch, -1, 1, 96])[:, :est_end, :, :]
        emb = tf.squeeze(emb, axis=2)  # [batch, frames, 96]
        # Down-sample the sample-level mask to the frame rate (one value per hop).
        emb.mask = mask[:, self.shift - 1::self.shift]
        emb._keras_mask = emb.mask
        return emb
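
# Minimal usage sketch (illustrative, not part of the original file). It assumes the
# bundled 'google_speech_embedding' SavedModel directory is present next to this file
# and that the input is a batch of 16 kHz mono waveforms, zero-padded to equal length;
# the random tensor below is only a stand-in for real audio.
if __name__ == "__main__":
    embedder = GoogleSpeechEmbedder()
    # Two dummy one-second utterances at an assumed 16 kHz sample rate: [batch=2, samples=16000].
    dummy_speech = tf.random.uniform([2, 16000], minval=-1.0, maxval=1.0)
    embeddings = embedder(dummy_speech)  # expected shape: [2, frames, 96]
    print(embeddings.shape, embeddings._keras_mask.shape)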