"""Gradio demo for user-defined keyword spotting with a ``ukws.BaseUKWS`` model.

Builds the model (optionally restoring the latest checkpoint found in
``--load_checkpoint_path``) and serves a web UI that scores an uploaded
recording against a comma-separated list of keywords.
"""
import argparse
import os
import warnings

# Set before TensorFlow is imported so it also applies to messages emitted
# while TensorFlow loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import numpy as np
from model import ukws
from dataset import dataloader_demo
import gradio as gr

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Silence every Python warning for a clean demo console. This blanket filter
# already covers NumPy's VisibleDeprecationWarning; referencing that class as
# ``np.VisibleDeprecationWarning`` raises AttributeError on NumPy >= 2.0.
warnings.simplefilter('ignore')

seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

parser = argparse.ArgumentParser()
parser.add_argument('--text_input', required=False, type=str, default='g2p_embed',
                    help="Text feature type, forwarded to the model and dataloader.")
parser.add_argument('--audio_input', required=False, type=str, default='both',
                    help="Audio input mode, forwarded to the model.")
# NOTE(review): machine-specific default path; override it on other hosts.
parser.add_argument('--load_checkpoint_path',
                    default='/share/nas165/yiting/CL-KWS_202408_v1/checkpoint_results/checkpoint_guided_ctc/20240725-011006',
                    help="Directory to restore the latest checkpoint from.")
parser.add_argument('--keyword_list_length', default=8, type=int,
                    help="Per-replica batch size used when scoring keywords.")
parser.add_argument('--stack_extractor', action='store_true')
parser.add_argument('--comment', required=False, type=str)
parser.add_argument('--threshold', default=0.8, type=float,
                    help="Minimum (rounded) score for a keyword to be reported.")
args = parser.parse_args()

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Allocate GPU memory on demand instead of reserving it all up front.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when the GPUs were already initialised; keep going.
        print(e)

strategy = tf.distribute.MirroredStrategy()
batch_size = args.keyword_list_length  # Batch size per GPU
GLOBAL_BATCH_SIZE = batch_size * strategy.num_replicas_in_sync
DETECTION_THRESHOLD = args.threshold

text_input = args.text_input
audio_input = args.audio_input
load_checkpoint_path = args.load_checkpoint_path

# Phoneme vocabulary: index 0 is the empty token, the last entry is a space.
phonemes = ["", ] + ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',
                     'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D',
                     'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2', 'F',
                     'G', 'HH', 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M',
                     'N', 'NG', 'OW0', 'OW1', 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
                     'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y',
                     'Z', 'ZH', ' ']
# Number of phonemes
vocab = len(phonemes)

# Model params.
kwargs = {
    'vocab': vocab,
    'text_input': text_input,
    'audio_input': audio_input,
    'frame_length': 400,
    'hop_length': 160,
    'num_mel': 40,
    'sample_rate': 16000,
    'log_mel': False,
    'stack_extractor': args.stack_extractor,
}

# Make tensorboard dict.
# NOTE(review): ``param`` aliases ``kwargs`` (no copy), so 'comment' is also
# passed to BaseUKWS below. Kept as-is; confirm whether that is intended.
param = kwargs
param['comment'] = args.comment

with strategy.scope():
    model = ukws.BaseUKWS(**kwargs)
    if load_checkpoint_path:
        checkpoint = tf.train.Checkpoint(model=model)
        latest_checkpoint = tf.train.latest_checkpoint(load_checkpoint_path)
        if latest_checkpoint:
            # Inference only: optimizer slots etc. in the checkpoint are not
            # needed, so silence warnings about values left unrestored.
            checkpoint.restore(latest_checkpoint).expect_partial()
            print("Checkpoint restored!")
        else:
            print("No checkpoint found.")


def _score_batch(inputs):
    """Run the model on one dataset batch and return its scores as a flat array.

    ``inputs[0]`` is passed to the model as audio and ``inputs[1]`` as text,
    matching the dataloader's element layout used by the original demo.
    NOTE(review): assumes a single replica; with several GPUs the elements
    would be PerReplica values — TODO confirm.
    """
    clean_speech, text = inputs[0], inputs[1]
    prob = model(clean_speech, text, training=False)[0]
    # Round to 3 decimals before thresholding and display.
    prob = tf.round(prob * 1000) / 1000
    return prob.numpy().flatten()


def inference(audio, keyword):
    """Score ``audio`` against each keyword and report the best match.

    Args:
        audio: Recording from the Gradio audio widget (or a path when called
            manually); forwarded unchanged to the dataloader as
            ``wav_path_or_object``.
        keyword: Comma-separated keyword string, or a list of keywords.
            Blank entries are ignored.

    Returns:
        ``(detected, scores_text)``. ``detected`` is the highest-scoring
        keyword if its score is >= ``DETECTION_THRESHOLD``, otherwise
        ``'no keyword'``. ``scores_text`` holds one ``"<keyword> | <score>"``
        line per keyword.
    """
    if keyword is None:
        keyword = ''
    if isinstance(keyword, str):
        keyword = keyword.split(',')
    # Keep the cleaned list in its own name: it must not be overwritten while
    # iterating over batches.
    keyword_list = [kw for kw in (k.strip() for k in keyword) if kw]
    if not keyword_list:
        return 'no keyword', ''

    dataset = dataloader_demo.GoogleCommandsDataloader(
        batch_size=GLOBAL_BATCH_SIZE,
        features=text_input,
        wav_path_or_object=audio,
        keyword=keyword_list,
    )
    dataset = dataloader_demo.convert_sequence_to_dataset(dataset)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    # Collect scores from every batch, not just the last one. Assumes batches
    # follow keyword_list order, one score per keyword — TODO confirm against
    # dataloader_demo. Any extra (e.g. padded) scores are dropped so the
    # argmax can never index past keyword_list.
    batch_scores = [_score_batch(batch) for batch in dist_dataset]
    if not batch_scores:
        return 'no keyword', ''
    prob = np.concatenate(batch_scores)[:len(keyword_list)]
    if prob.size == 0:
        return 'no keyword', ''

    print('keyword:', keyword_list)
    print('prob', prob)

    best = int(np.argmax(prob))
    detected = keyword_list[best] if prob[best] >= DETECTION_THRESHOLD else 'no keyword'
    msg = ''.join('{} | {:.2f} \n'.format(k, p) for k, p in zip(keyword_list, prob))
    return detected, msg


DEMO_KEYWORDS = 'realtek go,ok google,vintage,hackney,crocodile,surroundings,oversaw,northwestern'
DEMO_RECORDINGS = [
    "./recording/ok_google/ok_google-183000.wav",
    "./recording/ok_google/ok_google-183005.wav",
    "./recording/ok_google/ok_google-183008.wav",
    "./recording/ok_google/ok_google-183011.wav",
    "./recording/ok_google/ok_google-183015.wav",
    "./recording/realtek_go/realtek_go-183029.wav",
    "./recording/realtek_go/realtek_go-183033.wav",
    "./recording/realtek_go/realtek_go-183036.wav",
    "./recording/realtek_go/realtek_go-183039.wav",
    "./recording/realtek_go/realtek_go-183043.wav",
]

demo = gr.Interface(
    fn=inference,
    inputs=[gr.Audio(source="upload", label="Sound"),
            gr.Textbox(placeholder="Keyword List Here...", label="keyword_list")],
    examples=[[path, DEMO_KEYWORDS] for path in DEMO_RECORDINGS],
    outputs=[gr.Textbox(label="keyword"), gr.Textbox(label="Confidence Score of keyword")],
)
demo.launch(server_name='0.0.0.0', server_port=7860, share=True)