# CL-KWS_202408_v1 / demo.py
# Uploaded by Francis0917 via huggingface_hub (commit d58ae23, verified)
import os, warnings, argparse
import tensorflow as tf
import numpy as np
from model import ukws
from dataset import dataloader_demo
import gradio as gr
# import librosa
# Silence TensorFlow C++ logging and Python warnings so demo output stays clean.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
warnings.filterwarnings('ignore')
# np.VisibleDeprecationWarning was removed in NumPy 2.0; guard the lookup so the
# demo still starts on modern NumPy instead of dying with AttributeError.
_visible_dep_warning = getattr(np, 'VisibleDeprecationWarning', None)
if _visible_dep_warning is not None:
    warnings.filterwarnings("ignore", category=_visible_dep_warning)
warnings.simplefilter("ignore")

# Fix RNG seeds for reproducible behaviour across runs.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
# Command-line options controlling the text/audio front-ends, checkpoint
# location and keyword-list size.  Table-driven so flags are easy to scan.
parser = argparse.ArgumentParser()
_cli_options = (
    ('--text_input',           dict(required=False, type=str, default='g2p_embed')),
    ('--audio_input',          dict(required=False, type=str, default='both')),
    ('--load_checkpoint_path', dict(default='/share/nas165/yiting/CL-KWS_202408_v1/checkpoint_results/checkpoint_guided_ctc/20240725-011006')),
    ('--keyword_list_length',  dict(default=8, type=int)),
    ('--stack_extractor',      dict(action='store_true')),
    ('--comment',              dict(required=False, type=str)),
)
for _flag, _opts in _cli_options:
    parser.add_argument(_flag, **_opts)
args = parser.parse_args()
# Enable on-demand GPU memory growth so TF does not grab all VRAM up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        # Memory growth must be set before GPUs are initialised.
        print(err)

strategy = tf.distribute.MirroredStrategy()

# The demo scores every keyword against one utterance, so the per-replica
# batch size equals the keyword-list length.
batch_size = args.keyword_list_length
GLOBAL_BATCH_SIZE = batch_size * strategy.num_replicas_in_sync

# Unpack the CLI options used throughout the rest of the script.
text_input = args.text_input
audio_input = args.audio_input
load_checkpoint_path = args.load_checkpoint_path
# ARPABET phoneme inventory (CMU dict, with stress digits) used by the g2p
# text front-end.  Index 0 is reserved for the "<pad>" token and the final
# entry is the word-boundary space token.
_ARPABET = (
    'AA0 AA1 AA2 AE0 AE1 AE2 AH0 AH1 AH2 AO0 '
    'AO1 AO2 AW0 AW1 AW2 AY0 AY1 AY2 B CH '
    'D DH EH0 EH1 EH2 ER0 ER1 ER2 EY0 EY1 '
    'EY2 F G HH IH0 IH1 IH2 IY0 IY1 IY2 '
    'JH K L M N NG OW0 OW1 OW2 OY0 '
    'OY1 OY2 P R S SH T TH UH0 UH1 '
    'UH2 UW UW0 UW1 UW2 V W Y Z ZH'
)
phonemes = ["<pad>"] + _ARPABET.split() + [' ']

# Vocabulary size for the text-encoder embedding table.
vocab = len(phonemes)
# Model params.
kwargs = {
'vocab' : vocab,
'text_input' : text_input,
'audio_input' : audio_input,
'frame_length' : 400,
'hop_length' : 160,
'num_mel' : 40,
'sample_rate' : 16000,
'log_mel' : False,
'stack_extractor' : args.stack_extractor,
}
# Make tensorboard dict.
global keyword
param = kwargs
param['comment'] = args.comment
# Build the model under the distribution strategy and restore trained weights.
with strategy.scope():
    model = ukws.BaseUKWS(**kwargs)
    if args.load_checkpoint_path:
        checkpoint_dir = args.load_checkpoint_path
        checkpoint = tf.train.Checkpoint(model=model)
        # NOTE: the original also built a tf.train.CheckpointManager here but
        # never used it (restore-only script) — removed as dead code.
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if latest_checkpoint:
            checkpoint.restore(latest_checkpoint)
            print("Checkpoint restored!")
        else:
            print("No checkpoint found.")
def inference(audio, keyword):
    """Score *audio* against each keyword and return the best match.

    Args:
        audio: wav path (or audio object) coming from the Gradio widget.
        keyword: comma-separated keyword string from the textbox, or a list
            of keyword phrases.

    Returns:
        Tuple of (detected keyword or 'no keyword', printable
        "keyword | prob" score table).
    """
    DETECTION_THRESHOLD = 0.8  # minimum probability to accept a keyword

    # The Gradio textbox supplies a single comma-separated string.
    if isinstance(keyword, str):
        keyword = [kw.strip() for kw in keyword.split(',')]

    test_google_dataset = dataloader_demo.GoogleCommandsDataloader(
        batch_size=GLOBAL_BATCH_SIZE,
        features=text_input,
        wav_path_or_object=audio,
        keyword=keyword,
    )
    test_google_dataset = dataloader_demo.convert_sequence_to_dataset(test_google_dataset)
    test_google_dist_dataset = strategy.experimental_distribute_dataset(test_google_dataset)

    def test_step_metric_only(inputs, keyword_list):
        """Run one batch through the model; return (best keyword, score table)."""
        clean_speech, text = inputs[0], inputs[1]  # inputs[2] (labels) unused here
        prob = model(clean_speech, text, training=False)[0]
        prob = tf.round(prob * 1000) / 1000  # keep 3 decimal places
        prob = prob.numpy().flatten()
        best = int(np.argmax(prob))
        if prob[best] >= DETECTION_THRESHOLD:
            matched = keyword_list[best]
        else:
            matched = 'no keyword'
        print('keyword:', keyword_list)
        print('prob', prob)
        # str.join instead of repeated += (linear, not quadratic).
        msg = ''.join('{} | {:.2f} \n'.format(k, p)
                      for k, p in zip(keyword_list, prob))
        return matched, msg

    # Default result so an empty dataset cannot leave the names unbound.
    result, table = 'no keyword', ''
    for batch in test_google_dist_dataset:
        result, table = test_step_metric_only(batch, keyword)
    return result, table
# keyword = ['realtek go','ok google','vintage','hackney','crocodile','surroundings','oversaw','northwestern']
# audio = '/share/nas165/yiting/recording/ok_google/Default_20240725-183000.wav'
# inference(audio,keyword)
# Every bundled example clip shares the same candidate keyword list.
_EXAMPLE_KEYWORDS = ('realtek go,ok google,vintage,hackney,crocodile,'
                     'surroundings,oversaw,northwestern')
_EXAMPLE_CLIPS = [
    './recording/ok_google/ok_google-183000.wav',
    './recording/ok_google/ok_google-183005.wav',
    './recording/ok_google/ok_google-183008.wav',
    './recording/ok_google/ok_google-183011.wav',
    './recording/ok_google/ok_google-183015.wav',
    './recording/realtek_go/realtek_go-183029.wav',
    './recording/realtek_go/realtek_go-183033.wav',
    './recording/realtek_go/realtek_go-183036.wav',
    './recording/realtek_go/realtek_go-183039.wav',
    './recording/realtek_go/realtek_go-183043.wav',
]

# Gradio front-end: upload a clip, type a comma-separated keyword list, get
# back the detected keyword plus per-keyword confidence scores.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(source="upload", label="Sound"),
        gr.Textbox(placeholder="Keyword List Here...", label="keyword_list"),
    ],
    examples=[[clip, _EXAMPLE_KEYWORDS] for clip in _EXAMPLE_CLIPS],
    outputs=[gr.Textbox(label="keyword"),
             gr.Textbox(label="Confidence Score of keyword")],
)
demo.launch(server_name='0.0.0.0', server_port=7860, share=True)