"""Evaluate a UKWS checkpoint on the Google Speech Commands test pickle.

Loads the newest checkpoint from --load_checkpoint_path, runs the model over
the test set under a MirroredStrategy, and prints the percentage of query
groups whose highest-scoring row matches the row with the highest label.
"""
import sys, os, datetime, warnings, argparse

# TensorFlow reads TF_CPP_MIN_LOG_LEVEL when its native library is loaded, so
# it must be set before `import tensorflow` (and before project modules that
# may import TensorFlow themselves); setting it afterwards has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import numpy as np
from model import ukws
from dataset import google_infe202405

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Silence all Python warnings. Earlier per-category filters referenced
# `np.warnings` (removed in NumPy 1.24) and `np.VisibleDeprecationWarning`
# (removed from the top-level namespace in NumPy 2.0), which raise
# AttributeError on current NumPy; this blanket filter already covers them.
warnings.simplefilter("ignore")

seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

parser = argparse.ArgumentParser()
parser.add_argument('--text_input', required=False, type=str, default='g2p_embed')
parser.add_argument('--audio_input', required=False, type=str, default='both')
parser.add_argument('--load_checkpoint_path', required=True, type=str)
parser.add_argument('--google_pkl', required=False, type=str, default='/home/DB/data/google_test_all.pkl')
parser.add_argument('--stack_extractor', action='store_true')
args = parser.parse_args()

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when memory growth is changed after the GPUs were already
        # initialized; report it and continue with the default allocator.
        print(e)

strategy = tf.distribute.MirroredStrategy()

# Each replica processes 1000 rows per step; the global batch spans all replicas.
GLOBAL_BATCH_SIZE = 1000 * strategy.num_replicas_in_sync
BATCH_SIZE_PER_REPLICA = GLOBAL_BATCH_SIZE // strategy.num_replicas_in_sync

# Number of consecutive rows that form one evaluation query; the per-step
# reshape into [num_groups, NUM_CANDIDATES] relies on it. NOTE(review): assumes
# every per-replica batch (including the last one) is a multiple of this —
# confirm against google_infe202405.
NUM_CANDIDATES = 20

# Make Dataloader
text_input = args.text_input
audio_input = args.audio_input
load_checkpoint_path = args.load_checkpoint_path

test_google_dataset = google_infe202405.GoogleCommandsDataloader(
    batch_size=GLOBAL_BATCH_SIZE, features=text_input, shuffle=False, pkl=args.google_pkl)
test_google_dataset = google_infe202405.convert_sequence_to_dataset(test_google_dataset)
test_google_dist_dataset = strategy.experimental_distribute_dataset(test_google_dataset)

phonemes = ["", ] + ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2',
                     'AO0', 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2',
                     'B', 'CH', 'D', 'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2',
                     'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH0', 'IH1', 'IH2',
                     'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
                     'OW0', 'OW1', 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
                     'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2',
                     'V', 'W', 'Y', 'Z', 'ZH', ' ']
# Number of phonemes (vocabulary size passed to the model, including "" and " ").
vocab = len(phonemes)

# Model params.
kwargs = {
    'vocab': vocab,
    'text_input': text_input,
    'audio_input': audio_input,
    'frame_length': 400,
    'hop_length': 160,
    'num_mel': 40,
    'sample_rate': 16000,
    'log_mel': False,
    'stack_extractor': args.stack_extractor,
}

# Make tensorboard dict.
param = kwargs

with strategy.scope():
    model = ukws.BaseUKWS(**kwargs)
    checkpoint_dir = args.load_checkpoint_path
    checkpoint = tf.train.Checkpoint(model=model)
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if not latest_checkpoint:
        # Evaluating randomly initialized weights would print a meaningless
        # accuracy, so stop here instead of silently continuing.
        sys.exit(f"No checkpoint found in {checkpoint_dir!r}; aborting evaluation.")
    # Only the model is tracked here; expect_partial() suppresses warnings about
    # checkpoint values (e.g. optimizer slots) that this script never restores.
    checkpoint.restore(latest_checkpoint).expect_partial()
    print("Checkpoint restored!")


def test_step_metric_only(inputs):
    """Count correct top-1 selections for one per-replica batch.

    Rows are grouped into blocks of NUM_CANDIDATES consecutive rows. Within each
    block, the row with the highest model score is compared with the row whose
    label is highest.

    Args:
        inputs: indexable batch whose elements 0, 1, 2 are passed as speech,
            text and labels (ordering taken from this code; TODO confirm
            against google_infe202405).

    Returns:
        (true_count, num_testdata): a float32 scalar tensor with the number of
        groups whose argmax matched, and an int32 scalar tensor with the number
        of groups evaluated. Tensors (not `.numpy()` values) are returned so
        `strategy.reduce` can combine them across replicas.
    """
    clean_speech = inputs[0]
    text = inputs[1]
    labels = inputs[2]

    prob = model(clean_speech, text, training=False)[0]

    # tf.shape is valid in both eager and graph mode, whereas labels.shape[0]
    # can be None while tracing.
    num_testdata = tf.shape(labels)[0] // NUM_CANDIDATES
    prob = tf.reshape(prob, [num_testdata, NUM_CANDIDATES])
    labels = tf.reshape(labels, [num_testdata, NUM_CANDIDATES])

    predictions = tf.math.argmax(prob, axis=1)
    actuals = tf.math.argmax(labels, axis=1)
    true_count = tf.reduce_sum(tf.cast(tf.math.equal(predictions, actuals), tf.float32))
    return true_count, num_testdata


def distributed_test_step_metric_only(dataset_inputs):
    """Run one evaluation step on every replica and sum the results.

    Returns:
        (true_count, num_testdata) as a Python float and int, summed over all
        replicas. With more than one replica, `strategy.run` yields PerReplica
        values that cannot be added to plain numbers, hence the reduce.
    """
    per_replica_true, per_replica_num = strategy.run(test_step_metric_only, args=(dataset_inputs,))
    true_count = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_true, axis=None)
    num_testdata = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_num, axis=None)
    return float(true_count), int(num_testdata)


total_true_count = 0.0
total_num_testdata = 0
for x in test_google_dist_dataset:
    true_count, num_testdata = distributed_test_step_metric_only(x)
    total_true_count += true_count
    total_num_testdata += num_testdata

if total_num_testdata == 0:
    # Guard against ZeroDivisionError when the dataset yields no batches.
    sys.exit("No test data was evaluated; check --google_pkl.")

accuracy = total_true_count / total_num_testdata * 100.0
print("準確率:", accuracy, "%")