CL-KWS_202408_v1 / train_CLKWS.py
Francis0917's picture
Upload folder using huggingface_hub
2045faa verified
raw
history blame
16.3 kB
import sys, os, datetime, warnings, argparse
import tensorflow as tf
import numpy as np
from model import ukws
from dataset import libriphrase, google, qualcomm,libriphrase
from criterion import total_CLKWS
from criterion.utils import eer
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)
warnings.simplefilter("ignore")
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', required=True, type=int)
parser.add_argument('--lr', required=True, type=float)
parser.add_argument('--loss_weight', default=[1.0, 1.0], nargs=2, type=float)
parser.add_argument('--text_input', required=False, type=str, default='g2p_embed')
parser.add_argument('--audio_input', required=False, type=str, default='both')
parser.add_argument('--load_checkpoint_path', required=True, type=str)
parser.add_argument('--train_pkl', required=False, type=str, default='/home/DB/data/libriandGSC__fix300_5_shuffle.pkl')
parser.add_argument('--google_pkl', required=False, type=str, default='/home/DB//data/google.pkl')
parser.add_argument('--qualcomm_pkl', required=False, type=str, default='/home/DB/data/qualcomm.pkl')
parser.add_argument('--libriphrase_pkl', required=False, type=str, default='/home/DB/data/test_both.pkl')
parser.add_argument('--stack_extractor', action='store_true')
parser.add_argument('--comment', required=False, type=str)
args = parser.parse_args()
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
try:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
print(e)
strategy = tf.distribute.MirroredStrategy()
# Batch size per GPU
# GLOBAL_BATCH_SIZE = 400 * strategy.num_replicas_in_sync
GLOBAL_BATCH_SIZE = 1000
BATCH_SIZE_PER_REPLICA = GLOBAL_BATCH_SIZE / strategy.num_replicas_in_sync
# Make Dataloader
text_input = args.text_input
audio_input = args.audio_input
load_checkpoint_path = args.load_checkpoint_path
train_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=True, types='both', shuffle=True, pkl=args.train_pkl)
test_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='both', shuffle=True, pkl=args.libriphrase_pkl)
test_easy_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='easy', shuffle=True, pkl=args.libriphrase_pkl)
test_hard_dataset = libriphrase.LibriPhraseDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, train=False, types='hard', shuffle=True, pkl=args.libriphrase_pkl)
test_google_dataset = google.GoogleCommandsDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=True, pkl=args.google_pkl)
test_qualcomm_dataset = qualcomm.QualcommKeywordSpeechDataloader(batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=True, pkl=args.qualcomm_pkl)
# Number of phonemes
vocab = train_dataset.nPhoneme
# Convert tf.utils.sequence to tf.dataset
train_dataset = libriphrase.convert_sequence_to_dataset(train_dataset)
test_dataset = libriphrase.convert_sequence_to_dataset(test_dataset)
test_easy_dataset = libriphrase.convert_sequence_to_dataset(test_easy_dataset)
test_hard_dataset = libriphrase.convert_sequence_to_dataset(test_hard_dataset)
test_google_dataset = google.convert_sequence_to_dataset(test_google_dataset)
test_qualcomm_dataset = qualcomm.convert_sequence_to_dataset(test_qualcomm_dataset)
# Make disribute dataset for multi-gpu
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
test_easy_dist_dataset = strategy.experimental_distribute_dataset(test_easy_dataset)
test_hard_dist_dataset = strategy.experimental_distribute_dataset(test_hard_dataset)
test_google_dist_dataset = strategy.experimental_distribute_dataset(test_google_dataset)
test_qualcomm_dist_dataset = strategy.experimental_distribute_dataset(test_qualcomm_dataset)
# Model params.
kwargs = {
'vocab' : vocab,
'text_input' : text_input,
'audio_input' : audio_input,
'frame_length' : 400,
'hop_length' : 160,
'num_mel' : 40,
'sample_rate' : 16000,
'log_mel' : False,
'stack_extractor' : args.stack_extractor,
}
# Train params.
EPOCHS = args.epoch
lr = args.lr
# Make tensorboard dict.
param = kwargs
param['epoch'] = EPOCHS
param['lr'] = lr
param['loss weight'] = args.loss_weight
param['comment'] = args.comment
checkpoint_dir = './checkpoint_results/checkpoint_CLKWS/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
result_dir='./checkpoint_results/result_CLKWS/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
if not os.path.exists(result_dir):
os.makedirs(result_dir, exist_ok=True)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
result_file=os.path.join(result_dir, "result.txt")
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
tensorboard_prefix = os.path.join(checkpoint_dir, "tensorboard")
with strategy.scope():
loss_object = total_CLKWS.TotalLoss_SCE(weight=args.loss_weight)
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_loss_d = tf.keras.metrics.Mean(name='train_loss_Utt')
train_loss_sce = tf.keras.metrics.Mean(name='train_loss_Phon')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_loss_d = tf.keras.metrics.Mean(name='test_loss_Utt')
train_auc = tf.keras.metrics.AUC(name='train_auc')
train_eer = eer(name='train_eer')
test_auc = tf.keras.metrics.AUC(name='test_auc')
test_eer = eer(name='test_eer')
test_easy_auc = tf.keras.metrics.AUC(name='test_easy_auc')
test_easy_eer = eer(name='test_easy_eer')
test_hard_auc = tf.keras.metrics.AUC(name='test_hard_auc')
test_hard_eer = eer(name='test_hard_eer')
google_auc = tf.keras.metrics.AUC(name='google_auc')
google_eer = eer(name='google_eer')
qualcomm_auc = tf.keras.metrics.AUC(name='qualcomm_auc')
qualcomm_eer = eer(name='qualcomm_eer')
model = ukws.BaseUKWS(**kwargs)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
if args.load_checkpoint_path:
checkpoint_dir=args.load_checkpoint_path
checkpoint = tf.train.Checkpoint(model=model)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
print("Checkpoint restored!")
@tf.function
def train_step(inputs):
clean_speech, noisy_speech, text, labels, speech_labels, text_labels = inputs
with tf.GradientTape(watch_accessed_variables=False, persistent=False) as tape:
model(clean_speech, text, training=False)
tape.watch(model.trainable_variables)
prob, affinity_matrix, LD, sce_logit = model(noisy_speech, text, training=True)
loss, LD, LC = loss_object(labels, LD, speech_labels, text_labels, sce_logit, prob)
loss /= GLOBAL_BATCH_SIZE
LC /= GLOBAL_BATCH_SIZE
LD /= GLOBAL_BATCH_SIZE
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss.update_state(loss)
train_loss_d.update_state(LD)
train_loss_sce.update_state(LC)
train_auc.update_state(labels, prob)
train_eer.update_state(labels, prob)
return loss, tf.expand_dims(tf.cast(affinity_matrix * 255, tf.uint8), -1), labels
@tf.function
def test_step(inputs):
clean_speech = inputs[0]
text = inputs[1]
labels = inputs[2]
prob, affinity_matrix, LD, LC = model(clean_speech, text, training=False)[:4]
t_loss, LD = total_CLKWS.TotalLoss(weight=args.loss_weight[0])(labels, LD)
t_loss /= GLOBAL_BATCH_SIZE
LD /= GLOBAL_BATCH_SIZE
test_loss.update_state(t_loss)
test_loss_d.update_state(LD)
test_auc.update_state(labels, prob)
test_eer.update_state(labels, prob)
return t_loss, tf.expand_dims(tf.cast(affinity_matrix * 255, tf.uint8), -1), labels
@tf.function
def test_step_metric_only(inputs, metric=[]):
clean_speech = inputs[0]
text = inputs[1]
labels = inputs[2]
prob = model(clean_speech, text, training=False)[0]
for m in metric:
m.update_state(labels, prob)
train_log_dir = os.path.join(tensorboard_prefix, "train")
test_log_dir = os.path.join(tensorboard_prefix, "test")
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)
def distributed_train_step(dataset_inputs):
per_replica_losses, per_replica_affinity_matrix, per_replica_labels = strategy.run(train_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None), strategy.experimental_local_results(per_replica_affinity_matrix)[0], strategy.experimental_local_results(per_replica_labels)[0]
def distributed_test_step(dataset_inputs):
per_replica_losses, per_replica_affinity_matrix, per_replica_labels = strategy.run(test_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None), strategy.experimental_local_results(per_replica_affinity_matrix)[0], strategy.experimental_local_results(per_replica_labels)[0]
def distributed_test_step_metric_only(dataset_inputs, metric=[]):
strategy.run(test_step_metric_only, args=(dataset_inputs, metric))
with train_summary_writer.as_default():
tf.summary.text('Hyperparameters', tf.stack([tf.convert_to_tensor([k, str(v)]) for k, v in param.items()]), step=0)
for epoch in range(EPOCHS):
# TRAIN LOOP
train_matrix = None
train_labels = None
test_matrix = None
train_labels = None
for i, x in enumerate(train_dist_dataset):
_, train_matrix, train_labels = distributed_train_step(x)
match_train_matrix = []
unmatch_train_matrix = []
for i, x in enumerate(train_labels):
if x == 1:
match_train_matrix.append(train_matrix[i])
elif x == 0:
unmatch_train_matrix.append(train_matrix[i])
with train_summary_writer.as_default():
tf.summary.scalar('0. Total loss', train_loss.result(), step=epoch)
tf.summary.scalar('1. Utterance-level Detection loss', train_loss_d.result(), step=epoch)
tf.summary.scalar('2. Phoneme-levle Detection loss', train_loss_sce.result(), step=epoch)
tf.summary.scalar('3. AUC', train_auc.result(), step=epoch)
tf.summary.scalar('4. EER', train_eer.result(), step=epoch)
tf.summary.image("Affinity matrix (match)", match_train_matrix, max_outputs=5, step=epoch)
tf.summary.image("Affinity matrix (unmatch)", unmatch_train_matrix, max_outputs=5, step=epoch)
# TEST LOOP
for x in test_dist_dataset:
_, test_matrix, test_labels = distributed_test_step(x)
match_test_matrix = []
unmatch_test_matrix = []
for i, x in enumerate(test_labels):
if x == 1:
match_test_matrix.append(test_matrix[i])
elif x == 0:
unmatch_test_matrix.append(test_matrix[i])
for x in test_easy_dist_dataset:
distributed_test_step_metric_only(x, metric=[test_easy_auc, test_easy_eer])
for x in test_hard_dist_dataset:
distributed_test_step_metric_only(x, metric=[test_hard_auc, test_hard_eer])
for x in test_google_dist_dataset:
distributed_test_step_metric_only(x, metric=[google_auc, google_eer])
for x in test_qualcomm_dist_dataset:
distributed_test_step_metric_only(x, metric=[qualcomm_auc, qualcomm_eer])
with test_summary_writer.as_default():
tf.summary.scalar('0. Total loss', test_loss.result(), step=epoch)
tf.summary.scalar('1. Utterance-level Detection loss', test_loss_d.result(), step=epoch)
tf.summary.scalar('3. AUC', test_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (EASY)', test_easy_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (HARD)', test_hard_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (Google)', google_auc.result(), step=epoch)
tf.summary.scalar('3. AUC (Qualcomm)', qualcomm_auc.result(), step=epoch)
tf.summary.scalar('4. EER', test_eer.result(), step=epoch)
tf.summary.scalar('4. EER (EASY)', test_easy_eer.result(), step=epoch)
tf.summary.scalar('4. EER (HARD)', test_hard_eer.result(), step=epoch)
tf.summary.scalar('4. EER (Google)', google_eer.result(), step=epoch)
tf.summary.scalar('4. EER (Qualcomm)', qualcomm_eer.result(), step=epoch)
tf.summary.image("Affinity matrix (match)", match_test_matrix, max_outputs=5, step=epoch)
tf.summary.image("Affinity matrix (unmatch)", unmatch_test_matrix, max_outputs=5, step=epoch)
if epoch % 1 == 0:
checkpoint.save(checkpoint_prefix)
template = ("Epoch {} | TRAIN | Loss {:.3f}, AUC {:.2f}, EER {:.2f} | EER | G {:.2f}, Q {:.2f}, LE {:.2f}, LH {:.2f} | AUC | G {:.2f}, Q {:.2f}, LE {:.2f}, LH {:.2f} |")
print (template.format(epoch + 1,
train_loss.result(),
train_auc.result() * 100,
train_eer.result() * 100,
google_eer.result() * 100,
qualcomm_eer.result() * 100,
test_easy_eer.result() * 100,
test_hard_eer.result() * 100,
google_auc.result() * 100,
qualcomm_auc.result() * 100,
test_easy_auc.result() * 100,
test_hard_auc.result() * 100,
)
)
with open(result_file, 'a') as file:
file.write(template.format(epoch + 1,
train_loss.result(),
train_auc.result() * 100,
train_eer.result() * 100,
google_eer.result() * 100,
qualcomm_eer.result() * 100,
test_easy_eer.result() * 100,
test_hard_eer.result() * 100,
google_auc.result() * 100,
qualcomm_auc.result() * 100,
test_easy_auc.result() * 100,
test_hard_auc.result() * 100,
)+'\n'
)
train_loss.reset_states()
test_loss.reset_states()
train_auc.reset_states()
test_auc.reset_states()
test_easy_auc.reset_states()
test_hard_auc.reset_states()
train_eer.reset_states()
test_eer.reset_states()
test_easy_eer.reset_states()
test_hard_eer.reset_states()
google_eer.reset_states()
qualcomm_eer.reset_states()
google_auc.reset_states()
qualcomm_auc.reset_states()