CL-KWS_202408_v1 / model /ukws_guided_ctc.py
Francis0917's picture
Upload folder using huggingface_hub
2045faa verified
raw
history blame
4.84 kB
import os, sys
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras import layers
import datetime
sys.path.append(os.path.dirname(__file__))
import encoder, extractor, discriminator, log_melspectrogram, speech_embedding
from utils import make_feature_matrix as concat_sequence
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
class ukws(Model):
"""Base class for user-defined kws mdoel"""
def __init__(self, name="ukws", **kwargs):
super(ukws, self).__init__(name=name)
def call(self, speech, text):
"""
Args:
speech : speech feature of shape `(batch, time)`
text : text embedding of shape `(batch, phoneme)`
"""
raise NotImplementedError
class BaseUKWS(ukws):
"""Base class for user-defined kws mdoel"""
def __init__(self, name="BaseUKWS", **kwargs):
super(BaseUKWS, self).__init__(name=name)
embedding=128
self.audio_input = kwargs['audio_input']
self.text_input = kwargs['text_input']
self.stack_extractor = kwargs['stack_extractor']
_stft={
'frame_length' : kwargs['frame_length'],
'hop_length' : kwargs['hop_length'],
'num_mel' : kwargs['num_mel'] ,
'sample_rate' : kwargs['sample_rate'],
'log_mel' : kwargs['log_mel'],
}
_ae = {
# [filter, kernel size, stride]
'conv' : [[embedding, 5, 2], [embedding * 2, 5, 1]],
# [unit]
'gru' : [[embedding], [embedding]],
# fully-connected layer unit
'fc' : embedding,
'audio_input' : self.audio_input,
}
_te = {
# fully-connected layer unit
'fc' : embedding,
# number of uniq. phonemes
'vocab' : kwargs['vocab'],
'text_input' : kwargs['text_input'],
}
_ext = {
# [unit]
'embedding' : embedding,
}
_dis = {
# [unit]
'gru' : [[embedding],],
}
if self.audio_input == 'both':
self.SPEC = log_melspectrogram.LogMelgramLayer(**_stft)
self.EMBD = speech_embedding.GoogleSpeechEmbedder()
self.AE = encoder.EfficientAudioEncoder(downsample=False, **_ae)
else:
if self.audio_input == 'raw':
self.FEAT = log_melspectrogram.LogMelgramLayer(**_stft)
elif self.audio_input == 'google_embed':
self.FEAT = speech_embedding.GoogleSpeechEmbedder()
self.AE = encoder.AudioEncoder(**_ae)
self.TE = encoder.TextEncoder(**_te)
if kwargs['stack_extractor']:
self.EXT = extractor.StackExtractor(**_ext)
else:
self.EXT = extractor.BaseExtractor(**_ext)
self.DIS = discriminator.BaseDiscriminator(**_dis)
self.seq_ce_logit = layers.Dense(1, name='sequence_ce')
def call(self, speech, text):
"""
Args:
speech : speech feature of shape `(batch, time)`
text : text embedding of shape `(batch, phoneme)`
"""
if self.audio_input == 'both':
s = self.SPEC(speech)
g = self.EMBD(speech)
emb_s, LDN = self.AE(s, g)
else:
feat = self.FEAT(speech)
emb_s, LDN = self.AE(feat)
emb_t = self.TE(text)
attention_output, affinity_matrix = self.EXT(emb_s, emb_t)
prob, LD = self.DIS(attention_output)
n_speech = tf.math.reduce_sum(tf.cast(emb_s._keras_mask, tf.float32), -1)
if self.stack_extractor:
n_speech = tf.math.reduce_sum(tf.cast(emb_s._keras_mask, tf.float32), -1)
n_text = tf.math.reduce_sum(tf.cast(emb_t._keras_mask, tf.float32), -1)
n_total = n_speech + n_text
valid_mask = tf.sequence_mask(n_total, maxlen=tf.shape(attention_output)[1], dtype=tf.float32) - tf.sequence_mask(n_speech, maxlen=tf.shape(attention_output)[1], dtype=tf.float32)
valid_attention_output = tf.ragged.boolean_mask(attention_output, tf.cast(valid_mask, tf.bool)).to_tensor(0.)
seq_ce_logit = self.seq_ce_logit(valid_attention_output)[:,:,0]
seq_ce_logit = tf.pad(seq_ce_logit, [[0, 0],[0, tf.shape(emb_t)[1] - tf.shape(seq_ce_logit)[1]]], 'CONSTANT', constant_values=0.)
seq_ce_logit._keras_mask = emb_t._keras_mask
else:
seq_ce_logit = self.seq_ce_logit(attention_output)[:,:,0]
seq_ce_logit._keras_mask = attention_output._keras_mask
return prob, affinity_matrix, LD, seq_ce_logit, n_speech