import sys, os, datetime, warnings, argparse

import tensorflow as tf
import numpy as np

from model import ukws
from dataset import libriphrase, google, qualcomm
from criterion import total_CLKWS
from criterion.utils import eer

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
warnings.simplefilter("ignore")

seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', required=True, type=int)
parser.add_argument('--lr', required=True, type=float)
parser.add_argument('--loss_weight', default=[1.0, 1.0], nargs=2, type=float)
parser.add_argument('--text_input', required=False, type=str, default='g2p_embed')
parser.add_argument('--audio_input', required=False, type=str, default='both')
parser.add_argument('--load_checkpoint_path', required=True, type=str)
parser.add_argument('--train_pkl', required=False, type=str, default='/home/DB/data/libriandGSC__fix300_5_shuffle.pkl')
parser.add_argument('--google_pkl', required=False, type=str, default='/home/DB/data/google.pkl')
parser.add_argument('--qualcomm_pkl', required=False, type=str, default='/home/DB/data/qualcomm.pkl')
parser.add_argument('--libriphrase_pkl', required=False, type=str, default='/home/DB/data/test_both.pkl')
parser.add_argument('--stack_extractor', action='store_true')
parser.add_argument('--comment', required=False, type=str)
args = parser.parse_args()

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

strategy = tf.distribute.MirroredStrategy()

# Batch size per GPU
# GLOBAL_BATCH_SIZE = 400 * strategy.num_replicas_in_sync
GLOBAL_BATCH_SIZE = 1000
BATCH_SIZE_PER_REPLICA = GLOBAL_BATCH_SIZE // strategy.num_replicas_in_sync

# Make dataloaders
text_input = args.text_input
audio_input = args.audio_input
load_checkpoint_path = args.load_checkpoint_path

train_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=True, types='both', shuffle=True, pkl=args.train_pkl)
test_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='both', shuffle=True, pkl=args.libriphrase_pkl)
test_easy_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='easy', shuffle=True, pkl=args.libriphrase_pkl)
test_hard_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='hard', shuffle=True, pkl=args.libriphrase_pkl)
test_google_dataset = google.GoogleCommandsDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=True, pkl=args.google_pkl)
test_qualcomm_dataset = qualcomm.QualcommKeywordSpeechDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=True, pkl=args.qualcomm_pkl)

# Number of phonemes
vocab = train_dataset.nPhoneme

# Convert tf.keras.utils.Sequence loaders to tf.data.Dataset
train_dataset = libriphrase.convert_sequence_to_dataset(train_dataset)
test_dataset = libriphrase.convert_sequence_to_dataset(test_dataset)
test_easy_dataset = libriphrase.convert_sequence_to_dataset(test_easy_dataset)
test_hard_dataset = libriphrase.convert_sequence_to_dataset(test_hard_dataset)
test_google_dataset = google.convert_sequence_to_dataset(test_google_dataset)
test_qualcomm_dataset = qualcomm.convert_sequence_to_dataset(test_qualcomm_dataset)
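# Optional sanity check (a sketch, not part of the pipeline itself): each
# converted dataset should yield batched tensors, and inspecting element_spec
# is a cheap way to verify shapes before distributing across replicas.
# print(train_dataset.element_spec)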
# Make distributed datasets for multi-GPU training
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
test_easy_dist_dataset = strategy.experimental_distribute_dataset(test_easy_dataset)
test_hard_dist_dataset = strategy.experimental_distribute_dataset(test_hard_dataset)
test_google_dist_dataset = strategy.experimental_distribute_dataset(test_google_dataset)
test_qualcomm_dist_dataset = strategy.experimental_distribute_dataset(test_qualcomm_dataset)

# Model params.
kwargs = {
    'vocab': vocab,
    'text_input': text_input,
    'audio_input': audio_input,
    'frame_length': 400,
    'hop_length': 160,
    'num_mel': 40,
    'sample_rate': 16000,
    'log_mel': False,
    'stack_extractor': args.stack_extractor,
}

# Train params.
EPOCHS = args.epoch
lr = args.lr

# Make TensorBoard hyperparameter dict. Copy kwargs so the extra logging keys
# do not leak into the model constructor below.
param = dict(kwargs)
param['epoch'] = EPOCHS
param['lr'] = lr
param['loss weight'] = args.loss_weight
param['comment'] = args.comment

# Use a single timestamp so the checkpoint and result directories match.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_dir = './checkpoint_results/checkpoint_CLKWS/' + timestamp
result_dir = './checkpoint_results/result_CLKWS/' + timestamp
os.makedirs(result_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

result_file = os.path.join(result_dir, "result.txt")
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
tensorboard_prefix = os.path.join(checkpoint_dir, "tensorboard")

with strategy.scope():
    loss_object = total_CLKWS.TotalLoss_SCE(weight=args.loss_weight)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_loss_d = tf.keras.metrics.Mean(name='train_loss_Utt')
    train_loss_sce = tf.keras.metrics.Mean(name='train_loss_Phon')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_loss_d = tf.keras.metrics.Mean(name='test_loss_Utt')

    train_auc = tf.keras.metrics.AUC(name='train_auc')
    train_eer = eer(name='train_eer')
    test_auc = tf.keras.metrics.AUC(name='test_auc')
    test_eer = eer(name='test_eer')
    test_easy_auc = tf.keras.metrics.AUC(name='test_easy_auc')
    test_easy_eer = eer(name='test_easy_eer')
    test_hard_auc = tf.keras.metrics.AUC(name='test_hard_auc')
    test_hard_eer = eer(name='test_hard_eer')
    google_auc = tf.keras.metrics.AUC(name='google_auc')
    google_eer = eer(name='google_eer')
    qualcomm_auc = tf.keras.metrics.AUC(name='qualcomm_auc')
    qualcomm_eer = eer(name='qualcomm_eer')

    model = ukws.BaseUKWS(**kwargs)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    if args.load_checkpoint_path:
        # Restore model weights only; the optimizer state is not reloaded.
        checkpoint_dir = args.load_checkpoint_path
        checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint:
        checkpoint.restore(latest_checkpoint)
        print("Checkpoint restored!")

@tf.function
def train_step(inputs):
    clean_speech, noisy_speech, text, labels, speech_labels, text_labels = inputs

    with tf.GradientTape(watch_accessed_variables=False, persistent=False) as tape:
        # Forward the clean speech first, before watching any variables, so
        # this pass is excluded from the gradient computation.
        model(clean_speech, text, training=False)
        tape.watch(model.trainable_variables)
        prob, affinity_matrix, LD, sce_logit = model(noisy_speech, text, training=True)
        loss, LD, LC = loss_object(labels, LD, speech_labels, text_labels, sce_logit, prob)
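        # Scale by the global batch size, not the per-replica size: under
        # MirroredStrategy the per-replica gradients are summed, so dividing
        # the (assumed sum-reduced) losses by GLOBAL_BATCH_SIZE yields the
        # correct global-batch average, per the tf.distribute training guide.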
        loss /= GLOBAL_BATCH_SIZE
        LC /= GLOBAL_BATCH_SIZE
        LD /= GLOBAL_BATCH_SIZE

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_loss_d.update_state(LD)
    train_loss_sce.update_state(LC)
    train_auc.update_state(labels, prob)
    train_eer.update_state(labels, prob)

    # Scale the affinity matrix to uint8 [0, 255] so it can be logged as an image.
    return loss, tf.expand_dims(tf.cast(affinity_matrix * 255, tf.uint8), -1), labels

@tf.function
def test_step(inputs):
    clean_speech = inputs[0]
    text = inputs[1]
    labels = inputs[2]

    prob, affinity_matrix, LD, LC = model(clean_speech, text, training=False)[:4]
    t_loss, LD = total_CLKWS.TotalLoss(weight=args.loss_weight[0])(labels, LD)
    t_loss /= GLOBAL_BATCH_SIZE
    LD /= GLOBAL_BATCH_SIZE

    test_loss.update_state(t_loss)
    test_loss_d.update_state(LD)
    test_auc.update_state(labels, prob)
    test_eer.update_state(labels, prob)

    return t_loss, tf.expand_dims(tf.cast(affinity_matrix * 255, tf.uint8), -1), labels

@tf.function
def test_step_metric_only(inputs, metric=()):
    clean_speech = inputs[0]
    text = inputs[1]
    labels = inputs[2]

    prob = model(clean_speech, text, training=False)[0]
    for m in metric:
        m.update_state(labels, prob)

train_log_dir = os.path.join(tensorboard_prefix, "train")
test_log_dir = os.path.join(tensorboard_prefix, "test")
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

def distributed_train_step(dataset_inputs):
    per_replica_losses, per_replica_affinity_matrix, per_replica_labels = strategy.run(train_step, args=(dataset_inputs,))
    return (strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None),
            strategy.experimental_local_results(per_replica_affinity_matrix)[0],
            strategy.experimental_local_results(per_replica_labels)[0])

def distributed_test_step(dataset_inputs):
    per_replica_losses, per_replica_affinity_matrix, per_replica_labels = strategy.run(test_step, args=(dataset_inputs,))
    return (strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None),
            strategy.experimental_local_results(per_replica_affinity_matrix)[0],
            strategy.experimental_local_results(per_replica_labels)[0])

def distributed_test_step_metric_only(dataset_inputs, metric=()):
    strategy.run(test_step_metric_only, args=(dataset_inputs, metric))

with train_summary_writer.as_default():
    tf.summary.text('Hyperparameters', tf.stack([tf.convert_to_tensor([k, str(v)]) for k, v in param.items()]), step=0)

for epoch in range(EPOCHS):
    # TRAIN LOOP
    train_matrix = None
    train_labels = None
    test_matrix = None
    test_labels = None
    for x in train_dist_dataset:
        _, train_matrix, train_labels = distributed_train_step(x)

    # Split the last batch's affinity matrices by label for visualization.
    match_train_matrix = []
    unmatch_train_matrix = []
    for i, label in enumerate(train_labels):
        if label == 1:
            match_train_matrix.append(train_matrix[i])
        elif label == 0:
            unmatch_train_matrix.append(train_matrix[i])

    with train_summary_writer.as_default():
        tf.summary.scalar('0. Total loss', train_loss.result(), step=epoch)
        tf.summary.scalar('1. Utterance-level Detection loss', train_loss_d.result(), step=epoch)
        tf.summary.scalar('2. Phoneme-level Detection loss', train_loss_sce.result(), step=epoch)
        tf.summary.scalar('3. AUC', train_auc.result(), step=epoch)
        tf.summary.scalar('4. EER', train_eer.result(), step=epoch)
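        # Affinity matrices were cast to uint8 [0, 255] in train_step, so they
        # can be logged directly as grayscale images (up to 5 per class).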
        tf.summary.image("Affinity matrix (match)", match_train_matrix, max_outputs=5, step=epoch)
        tf.summary.image("Affinity matrix (unmatch)", unmatch_train_matrix, max_outputs=5, step=epoch)

    # TEST LOOP
    for x in test_dist_dataset:
        _, test_matrix, test_labels = distributed_test_step(x)

    match_test_matrix = []
    unmatch_test_matrix = []
    for i, label in enumerate(test_labels):
        if label == 1:
            match_test_matrix.append(test_matrix[i])
        elif label == 0:
            unmatch_test_matrix.append(test_matrix[i])

    for x in test_easy_dist_dataset:
        distributed_test_step_metric_only(x, metric=[test_easy_auc, test_easy_eer])
    for x in test_hard_dist_dataset:
        distributed_test_step_metric_only(x, metric=[test_hard_auc, test_hard_eer])
    for x in test_google_dist_dataset:
        distributed_test_step_metric_only(x, metric=[google_auc, google_eer])
    for x in test_qualcomm_dist_dataset:
        distributed_test_step_metric_only(x, metric=[qualcomm_auc, qualcomm_eer])

    with test_summary_writer.as_default():
        tf.summary.scalar('0. Total loss', test_loss.result(), step=epoch)
        tf.summary.scalar('1. Utterance-level Detection loss', test_loss_d.result(), step=epoch)
        tf.summary.scalar('3. AUC', test_auc.result(), step=epoch)
        tf.summary.scalar('3. AUC (EASY)', test_easy_auc.result(), step=epoch)
        tf.summary.scalar('3. AUC (HARD)', test_hard_auc.result(), step=epoch)
        tf.summary.scalar('3. AUC (Google)', google_auc.result(), step=epoch)
        tf.summary.scalar('3. AUC (Qualcomm)', qualcomm_auc.result(), step=epoch)
        tf.summary.scalar('4. EER', test_eer.result(), step=epoch)
        tf.summary.scalar('4. EER (EASY)', test_easy_eer.result(), step=epoch)
        tf.summary.scalar('4. EER (HARD)', test_hard_eer.result(), step=epoch)
        tf.summary.scalar('4. EER (Google)', google_eer.result(), step=epoch)
        tf.summary.scalar('4. EER (Qualcomm)', qualcomm_eer.result(), step=epoch)
        tf.summary.image("Affinity matrix (match)", match_test_matrix, max_outputs=5, step=epoch)
        tf.summary.image("Affinity matrix (unmatch)", unmatch_test_matrix, max_outputs=5, step=epoch)

    # Save a checkpoint every epoch.
    checkpoint.save(checkpoint_prefix)

    template = ("Epoch {} | TRAIN | Loss {:.3f}, AUC {:.2f}, EER {:.2f} | EER | G {:.2f}, Q {:.2f}, LE {:.2f}, LH {:.2f} | AUC | G {:.2f}, Q {:.2f}, LE {:.2f}, LH {:.2f} |")
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_auc.result() * 100,
                          train_eer.result() * 100,
                          google_eer.result() * 100,
                          qualcomm_eer.result() * 100,
                          test_easy_eer.result() * 100,
                          test_hard_eer.result() * 100,
                          google_auc.result() * 100,
                          qualcomm_auc.result() * 100,
                          test_easy_auc.result() * 100,
                          test_hard_auc.result() * 100))

    with open(result_file, 'a') as file:
        file.write(template.format(epoch + 1,
                                   train_loss.result(),
                                   train_auc.result() * 100,
                                   train_eer.result() * 100,
                                   google_eer.result() * 100,
                                   qualcomm_eer.result() * 100,
                                   test_easy_eer.result() * 100,
                                   test_hard_eer.result() * 100,
                                   google_auc.result() * 100,
                                   qualcomm_auc.result() * 100,
                                   test_easy_auc.result() * 100,
                                   test_hard_auc.result() * 100) + '\n')

    # Reset every metric, including the loss means, so each epoch logs fresh statistics.
    train_loss.reset_states()
    train_loss_d.reset_states()
    train_loss_sce.reset_states()
    test_loss.reset_states()
    test_loss_d.reset_states()
    train_auc.reset_states()
    test_auc.reset_states()
    test_easy_auc.reset_states()
    test_hard_auc.reset_states()
    train_eer.reset_states()
    test_eer.reset_states()
    test_easy_eer.reset_states()
    test_hard_eer.reset_states()
    google_eer.reset_states()
    qualcomm_eer.reset_states()
    google_auc.reset_states()
    qualcomm_auc.reset_states()
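# Example invocation (a sketch; the script name, argument values, and the
# checkpoint path below are placeholders, adjust them to your environment):
#   python train_CLKWS.py --epoch 100 --lr 1e-3 \
#       --load_checkpoint_path ./checkpoint_results/checkpoint_CLKWS/20240101-000000 \
#       --stack_extractor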