CL-KWS_202408_v1 / train_guided_CTC.py
Francis0917's picture
Upload folder using huggingface_hub
2045faa verified
import sys, os, datetime, warnings, argparse
import tensorflow as tf
import numpy as np
from model import ukws_guided_ctc
from dataset import libriphrase_ctc1, google, qualcomm
from criterion import total_ctc1
from criterion.utils import eer
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
warnings.simplefilter("ignore")
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
checkpoint_dir = '/Home/DB/checkpoint&results/checkpoint_guided_ctc/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
result_dir='/Home/DB/checkpoint&results/result_guided_ctc/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
if not os.path.exists(result_dir):
os.makedirs(result_dir, exist_ok=True)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
result_file=os.path.join(result_dir, "result.txt")
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
tensorboard_prefix = os.path.join(checkpoint_dir, "tensorboard")
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', required=True, type=int)
parser.add_argument('--lr', required=True, type=float)
parser.add_argument('--loss_weight', default=[1.0, 1.0, 0.2], nargs=3, type=float)
parser.add_argument('--text_input', required=False, type=str, default='g2p_embed')
parser.add_argument('--audio_input', required=False, type=str, default='both')
parser.add_argument('--train_pkl', required=False, type=str, default='/Home/DB/data/train_360_100.pkl')
parser.add_argument('--google_pkl', required=False, type=str, default='/Home/DB/data/google.pkl')
parser.add_argument('--qualcomm_pkl', required=False, type=str, default='/Home/DB/data/qualcomm.pkl')
parser.add_argument('--libriphrase_pkl', required=False, type=str, default='/Home/DB/data/test_both.pkl')
parser.add_argument('--stack_extractor', action='store_true')
parser.add_argument('--comment', required=False, type=str)
args = parser.parse_args()
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
try:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
print(e)
strategy = tf.distribute.MirroredStrategy()
# Batch size per GPU
GLOBAL_BATCH_SIZE = 1024 * strategy.num_replicas_in_sync
BATCH_SIZE_PER_REPLICA = GLOBAL_BATCH_SIZE / strategy.num_replicas_in_sync
# Make Dataloader
text_input = args.text_input
audio_input = args.audio_input
train_dataset = libriphrase_ctc1.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=True, types='both', shuffle=True, pkl=args.train_pkl)
test_dataset = libriphrase_ctc1.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='both', shuffle=True, pkl=args.libriphrase_pkl)
test_easy_dataset = libriphrase_ctc1.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='easy', shuffle=True, pkl=args.libriphrase_pkl)
test_hard_dataset = libriphrase_ctc1.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='hard', shuffle=True, pkl=args.libriphrase_pkl)
test_google_dataset = google.GoogleCommandsDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=True, pkl=args.google_pkl)
test_qualcomm_dataset = qualcomm.QualcommKeywordSpeechDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=True, pkl=args.qualcomm_pkl)
# Number of phonemes
vocab = train_dataset.nPhoneme
# Convert tf.utils.sequence to tf.dataset
train_dataset = libriphrase_ctc1.convert_sequence_to_dataset(train_dataset)
test_dataset = libriphrase_ctc1.convert_sequence_to_dataset(test_dataset)
test_easy_dataset = libriphrase_ctc1.convert_sequence_to_dataset(test_easy_dataset)
test_hard_dataset = libriphrase_ctc1.convert_sequence_to_dataset(test_hard_dataset)
test_google_dataset = google.convert_sequence_to_dataset(test_google_dataset)
test_qualcomm_dataset = qualcomm.convert_sequence_to_dataset(test_qualcomm_dataset)
# Make disribute dataset for multi-gpu
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
test_easy_dist_dataset = strategy.experimental_distribute_dataset(test_easy_dataset)
test_hard_dist_dataset = strategy.experimental_distribute_dataset(test_hard_dataset)
test_google_dist_dataset = strategy.experimental_distribute_dataset(test_google_dataset)
test_qualcomm_dist_dataset = strategy.experimental_distribute_dataset(test_qualcomm_dataset)
# Model params.
kwargs = {
'vocab' : vocab,
'text_input' : text_input,
'audio_input' : audio_input,
'frame_length' : 400,
'hop_length' : 160,
'num_mel' : 40,
'sample_rate' : 16000,
'log_mel' : False,
'stack_extractor' : args.stack_extractor,
}
# Train params.
EPOCHS = args.epoch
lr = args.lr
# Make tensorboard dict.
param = kwargs
param['epoch'] = EPOCHS
param['lr'] = lr
param['loss weight'] = args.loss_weight
param['comment'] = args.comment
with strategy.scope():
loss_object = total_ctc1.TotalLoss_SCE(weight=args.loss_weight)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_loss_d = tf.keras.metrics.Mean(name='train_loss_Utt')
train_loss_sce = tf.keras.metrics.Mean(name='train_loss_Phon')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_loss_d = tf.keras.metrics.Mean(name='test_loss_Utt')
train_auc = tf.keras.metrics.AUC(name='train_auc')
train_eer = eer(name='train_eer')
test_auc = tf.keras.metrics.AUC(name='test_auc')
test_eer = eer(name='test_eer')
test_easy_auc = tf.keras.metrics.AUC(name='test_easy_auc')
test_easy_eer = eer(name='test_easy_eer')
test_hard_auc = tf.keras.metrics.AUC(name='test_hard_auc')
test_hard_eer = eer(name='test_hard_eer')
google_auc = tf.keras.metrics.AUC(name='google_auc')
google_eer = eer(name='google_eer')
qualcomm_auc = tf.keras.metrics.AUC(name='qualcomm_auc')
qualcomm_eer = eer(name='qualcomm_eer')
model = ukws_guided_ctc.BaseUKWS(**kwargs)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
@tf.function
def train_step(inputs):
clean_speech, noisy_speech, text, labels, speech_labels, text_labels = inputs
with tf.GradientTape(watch_accessed_variables=False, persistent=False) as tape:
model(clean_speech, text, training=False)
tape.watch(model.trainable_variables)
prob, affinity_matrix, LD, sce_logit,n_speech = model(noisy_speech, text, training=True)
loss, LD, LC = loss_object(labels, LD, speech_labels, text_labels, sce_logit,affinity_matrix,n_speech)
loss /= GLOBAL_BATCH_SIZE
LC /= GLOBAL_BATCH_SIZE
LD /= GLOBAL_BATCH_SIZE
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss.update_state(loss)
train_loss_d.update_state(LD)
train_loss_sce.update_state(LC)
train_auc.update_state(labels, prob)
train_eer.update_state(labels, prob)
return loss, tf.expand_dims(tf.cast(affinity_matrix * 255, tf.uint8), -1), labels
@tf.function
def test_step(inputs):
clean_speech = inputs[0]
text = inputs[1]
labels = inputs[2]
prob, affinity_matrix, LD, LC = model(clean_speech, text, training=False)[:4]
t_loss, LD = total_ctc1.TotalLoss(weight=args.loss_weight[0])(labels, LD)
t_loss /= GLOBAL_BATCH_SIZE
LD /= GLOBAL_BATCH_SIZE
test_loss.update_state(t_loss)
test_loss_d.update_state(LD)
test_auc.update_state(labels, prob)
test_eer.update_state(labels, prob)
return t_loss, tf.expand_dims(tf.cast(affinity_matrix * 255, tf.uint8), -1), labels
@tf.function
def test_step_metric_only(inputs, metric=[]):
clean_speech = inputs[0]
text = inputs[1]
labels = inputs[2]
prob = model(clean_speech, text, training=False)[0]
for m in metric:
m.update_state(labels, prob)
train_log_dir = os.path.join(tensorboard_prefix, "train")
test_log_dir = os.path.join(tensorboard_prefix, "test")
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)
def distributed_train_step(dataset_inputs):
per_replica_losses, per_replica_affinity_matrix, per_replica_labels = strategy.run(train_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None), strategy.experimental_local_results(per_replica_affinity_matrix)[0], strategy.experimental_local_results(per_replica_labels)[0]
def distributed_test_step(dataset_inputs):
per_replica_losses, per_replica_affinity_matrix, per_replica_labels = strategy.run(test_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None), strategy.experimental_local_results(per_replica_affinity_matrix)[0], strategy.experimental_local_results(per_replica_labels)[0]
def distributed_test_step_metric_only(dataset_inputs, metric=[]):
strategy.run(test_step_metric_only, args=(dataset_inputs, metric))
with train_summary_writer.as_default():
tf.summary.text('Hyperparameters', tf.stack([tf.convert_to_tensor([k, str(v)]) for k, v in param.items()]), step=0)
for epoch in range(EPOCHS):
# TRAIN LOOP
train_matrix = None
train_labels = None
test_matrix = None
train_labels = None
for i, x in enumerate(train_dist_dataset):
_, train_matrix, train_labels = distributed_train_step(x)
match_train_matrix = []
unmatch_train_matrix = []
for i, x in enumerate(train_labels):
if x == 1:
match_train_matrix.append(train_matrix[i])
elif x == 0:
unmatch_train_matrix.append(train_matrix[i])
with train_summary_writer.as_default():
tf.summary.scalar('0. Total loss', train_loss.result(), step=epoch)
tf.summary.scalar('1. Utterance-level Detection loss', train_loss_d.result(), step=epoch)
tf.summary.scalar('2. Phoneme-levle Detection loss', train_loss_sce.result(), step=epoch)
tf.summary.scalar('3. AUC', train_auc.result(), step=epoch)
tf.summary.scalar('4. EER', train_eer.result(), step=epoch)
tf.summary.image("Affinity matrix (match)", match_train_matrix, max_outputs=5, step=epoch)
tf.summary.image("Affinity matrix (unmatch)", unmatch_train_matrix, max_outputs=5, step=epoch)
# TEST LOOP
for x in test_dist_dataset:
_, test_matrix, test_labels = distributed_test_step(x)
match_test_matrix = []
unmatch_test_matrix = []
for i, x in enumerate(test_labels):
if x == 1:
match_test_matrix.append(test_matrix[i])
elif x == 0:
unmatch_test_matrix.append(test_matrix[i])
for x in test_easy_dist_dataset:
distributed_test_step_metric_only(x, metric=[test_easy_auc, test_easy_eer])
for x in test_hard_dist_dataset:
distributed_test_step_metric_only(x, metric=[test_hard_auc, test_hard_eer])
for x in test_google_dist_dataset:
distributed_test_step_metric_only(x, metric=[google_auc, google_eer])
for x in test_qualcomm_dist_dataset:
distributed_test_step_metric_only(x, metric=[qualcomm_auc, qualcomm_eer])
with test_summary_writer.as_default():
tf.summary.scalar('0. Total loss', test_loss.result(), step=epoch)
tf.summary.scalar('1. Utterance-level Detection loss', test_loss_d.result(), step=epoch)
tf.summary.scalar('3. AUC', test_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (EASY)', test_easy_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (HARD)', test_hard_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (Google)', google_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (Qualcomm)', qualcomm_auc.result(), step=epoch)
tf.summary.scalar('4. EER', test_eer.result(), step=epoch)
tf.summary.scalar('4. EER (EASY)', test_easy_eer.result(), step=epoch)
tf.summary.scalar('4. EER (HARD)', test_hard_eer.result(), step=epoch)
tf.summary.scalar('4. EER (Google)', google_eer.result(), step=epoch)
tf.summary.scalar('4. EER (Qualcomm)', qualcomm_eer.result(), step=epoch)
tf.summary.image("Affinity matrix (match)", match_test_matrix, max_outputs=5, step=epoch)
tf.summary.image("Affinity matrix (unmatch)", unmatch_test_matrix, max_outputs=5, step=epoch)
if epoch % 1 == 0:
checkpoint.save(checkpoint_prefix)
template = ("Epoch {} | TRAIN | Loss {:.3f}, AUC {:.2f}, EER {:.2f} | EER | G {:.2f}, Q {:.2f}, LE {:.2f}, LH {:.2f} | AUC | G {:.2f}, Q {:.2f}, LE {:.2f}, LH {:.2f} |")
print (template.format(epoch + 1,
train_loss.result(),
train_auc.result() * 100,
train_eer.result() * 100,
google_eer.result() * 100,
qualcomm_eer.result() * 100,
test_easy_eer.result() * 100,
test_hard_eer.result() * 100,
google_auc.result() * 100,
qualcomm_auc.result() * 100,
test_easy_auc.result() * 100,
test_hard_auc.result() * 100,
)
)
with open(result_file, 'a') as file:
file.write(template.format(epoch + 1,
train_loss.result(),
train_auc.result() * 100,
train_eer.result() * 100,
google_eer.result() * 100,
qualcomm_eer.result() * 100,
test_easy_eer.result() * 100,
test_hard_eer.result() * 100,
google_auc.result() * 100,
qualcomm_auc.result() * 100,
test_easy_auc.result() * 100,
test_hard_auc.result() * 100,
)+'\n'
)
train_loss.reset_states()
test_loss.reset_states()
train_auc.reset_states()
test_auc.reset_states()
test_easy_auc.reset_states()
test_hard_auc.reset_states()
train_eer.reset_states()
test_eer.reset_states()
test_easy_eer.reset_states()
test_hard_eer.reset_states()
google_eer.reset_states()
qualcomm_eer.reset_states()
google_auc.reset_states()
qualcomm_auc.reset_states()