# Page residue from the Hugging Face Spaces listing this script was copied
# from; the Space status shown there was "Runtime error".
import os
import warnings
import argparse

# TF_CPP_MIN_LOG_LEVEL has to be in the environment before TensorFlow's native
# logger emits its first message, which can already happen while the package is
# being imported. Setting it after the import (as before) may have no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402
import numpy as np  # noqa: E402
from model import ukws  # noqa: E402
from dataset import dataloader_demo  # noqa: E402
import gradio as gr  # noqa: E402
# import librosa

# Only show errors from TensorFlow's Python-side logger.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Silence all Python warnings for the demo.
warnings.filterwarnings('ignore')
# VisibleDeprecationWarning lives in `np.exceptions` on recent NumPy releases
# and is no longer exposed as `np.VisibleDeprecationWarning` in NumPy 2.x, so
# look it up in whichever namespace provides it instead of crashing at startup.
_visible_deprecation_warning = getattr(
    getattr(np, 'exceptions', np), 'VisibleDeprecationWarning', None
)
if _visible_deprecation_warning is not None:
    warnings.filterwarnings("ignore", category=_visible_deprecation_warning)
warnings.simplefilter("ignore")

# Seed TensorFlow and NumPy with the same fixed value.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
# Command-line options for the demo. Every option is optional; the defaults
# below are used when the script is started without arguments.
parser = argparse.ArgumentParser()
_cli_options = (
    ('--text_input', {'required': False, 'type': str, 'default': 'g2p_embed'}),
    ('--audio_input', {'required': False, 'type': str, 'default': 'both'}),
    # Directory searched for the latest checkpoint at startup.
    ('--load_checkpoint_path', {'default': '/share/nas165/yiting/CL-KWS_202408_v1/checkpoint_results/checkpoint_guided_ctc/20240725-011006'}),
    # Used as the per-replica batch size further down in the script.
    ('--keyword_list_length', {'default': 8, 'type': int}),
    ('--stack_extractor', {'action': 'store_true'}),
    ('--comment', {'required': False, 'type': str}),
)
for _flag, _settings in _cli_options:
    parser.add_argument(_flag, **_settings)
args = parser.parse_args()
# Enable memory growth on every visible GPU. With no GPU present the loop
# body never runs, so no separate emptiness check is needed.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # A RuntimeError here is reported but does not stop the script.
    print(err)
# Distribution strategy under which the model is built and the dataset is
# distributed.
strategy = tf.distribute.MirroredStrategy()

# Batch size per replica, taken from --keyword_list_length.
batch_size = args.keyword_list_length
# Batch size summed over all replicas; this is what the dataloader receives.
GLOBAL_BATCH_SIZE = strategy.num_replicas_in_sync * batch_size

# Plain module-level aliases for the parsed options used below.
load_checkpoint_path = args.load_checkpoint_path
audio_input = args.audio_input
text_input = args.text_input
# Phoneme inventory: a padding token first, then the ARPAbet symbols (vowels
# carry a stress digit), and finally a single-space token
# (presumably the word separator - confirm against the dataloader).
phonemes = ["<pad>"] + (
    "AA0 AA1 AA2 AE0 AE1 AE2 AH0 AH1 AH2 AO0 "
    "AO1 AO2 AW0 AW1 AW2 AY0 AY1 AY2 B CH "
    "D DH EH0 EH1 EH2 ER0 ER1 ER2 EY0 EY1 "
    "EY2 F G HH IH0 IH1 IH2 IY0 IY1 IY2 "
    "JH K L M N NG OW0 OW1 OW2 OY0 "
    "OY1 OY2 P R S SH T TH UH0 UH1 "
    "UH2 UW UW0 UW1 UW2 V W Y Z ZH"
).split() + [" "]

# Number of phonemes, padding and space tokens included.
vocab = len(phonemes)
# Model params.
kwargs = dict(
    vocab=vocab,
    text_input=text_input,
    audio_input=audio_input,
    frame_length=400,
    hop_length=160,
    num_mel=40,
    sample_rate=16000,
    log_mel=False,
    stack_extractor=args.stack_extractor,
)

# Make tensorboard dict.
# NOTE(review): `param` is the same dict object as `kwargs`, not a copy, so the
# 'comment' entry added here is also present in `kwargs` wherever it is used
# afterwards - confirm that this is intended.
param = kwargs
param.update(comment=args.comment)
# Build the model under the distribution strategy and, when a checkpoint
# directory is configured, restore the most recent checkpoint found there.
# NOTE(review): the original indentation was lost; the restore logic is kept
# inside the strategy scope here - confirm against the original layout.
with strategy.scope():
    model = ukws.BaseUKWS(**kwargs)

    if args.load_checkpoint_path:
        checkpoint_dir = args.load_checkpoint_path
        checkpoint = tf.train.Checkpoint(model=model)
        checkpoint_manager = tf.train.CheckpointManager(
            checkpoint,
            checkpoint_dir,
            max_to_keep=5,
        )
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if not latest_checkpoint:
            # The script continues with the freshly constructed model.
            print("No checkpoint found.")
        else:
            checkpoint.restore(latest_checkpoint)
            print("Checkpoint restored!")
# Minimum probability the best-scoring keyword must reach to be reported.
DETECTION_THRESHOLD = 0.8


def _score_batch(inputs):
    """Return the model's probabilities for one batch as a flat NumPy array."""
    clean_speech, text = inputs[0], inputs[1]
    # Only the first model output (the probabilities) is needed here.
    prob = model(clean_speech, text, training=False)[0]
    # Round to three decimals before leaving TensorFlow.
    prob = tf.round(prob * 1000) / 1000
    return prob.numpy().flatten()


def inference(audio, keyword):
    """Score an audio clip against a list of candidate keywords.

    Args:
        audio: The audio input as delivered by the Gradio Audio component. It
            is forwarded unchanged to the dataloader as `wav_path_or_object`.
        keyword: Either a comma-separated string of keywords or an iterable of
            keyword strings. Blank entries are ignored.

    Returns:
        A `(detected, report)` tuple. `detected` is the best-scoring keyword
        when its probability is at least `DETECTION_THRESHOLD`, otherwise
        `'no keyword'`. `report` has one `keyword | probability` line per
        candidate. When there is no audio, no keyword, or no batch to score,
        `('no keyword', '')` is returned.
    """
    if audio is None:
        return 'no keyword', ''

    if isinstance(keyword, str):
        keyword_list = [kw.strip() for kw in keyword.split(',')]
    else:
        keyword_list = list(keyword or [])
    # Drop blanks such as those produced by a trailing comma.
    keyword_list = [kw for kw in keyword_list if kw]
    if not keyword_list:
        return 'no keyword', ''

    dataset = dataloader_demo.GoogleCommandsDataloader(
        batch_size=GLOBAL_BATCH_SIZE,
        features=text_input,
        wav_path_or_object=audio,
        keyword=keyword_list,
    )
    dataset = dataloader_demo.convert_sequence_to_dataset(dataset)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    # The candidate list is kept in its own variable: the previous version
    # overwrote it with the detected keyword inside the loop, so any batch after
    # the first was scored against a string instead of the list.
    batch_probs = [_score_batch(batch) for batch in dist_dataset]
    if not batch_probs:
        # Previously this path failed because the result was never assigned.
        return 'no keyword', ''

    # assumes batches arrive in the same order as `keyword_list` - TODO confirm
    # against dataloader_demo.
    prob = np.concatenate(batch_probs)
    if prob.size == 0:
        return 'no keyword', ''

    best = int(np.argmax(prob))
    detected = keyword_list[best] if prob[best] >= DETECTION_THRESHOLD else 'no keyword'

    print('keyword:', keyword_list)
    print('prob', prob)

    msg = ''.join('{} | {:.2f} \n'.format(k, p) for k, p in zip(keyword_list, prob))
    return detected, msg
# keyword = ['realtek go','ok google','vintage','hackney','crocodile','surroundings','oversaw','northwestern'] | |
# audio = '/share/nas165/yiting/recording/ok_google/Default_20240725-183000.wav' | |
# inference(audio,keyword) | |
# Keyword list pre-filled for every bundled example.
_EXAMPLE_KEYWORDS = 'realtek go,ok google,vintage,hackney,crocodile,surroundings,oversaw,northwestern'
# Bundled recordings as (folder / file prefix, numeric file suffixes).
_EXAMPLE_CLIPS = (
    ('ok_google', ('183000', '183005', '183008', '183011', '183015')),
    ('realtek_go', ('183029', '183033', '183036', '183039', '183043')),
)

# Web UI: an uploaded clip plus a comma-separated keyword list go in; the
# detected keyword and the per-keyword scores come out.
# NOTE(review): `source="upload"` is the Gradio 3.x spelling; Gradio 4 expects
# `sources=[...]` - confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(source="upload", label="Sound"),
        gr.Textbox(placeholder="Keyword List Here...", label="keyword_list"),
    ],
    examples=[
        ["./recording/{0}/{0}-{1}.wav".format(phrase, suffix), _EXAMPLE_KEYWORDS]
        for phrase, suffixes in _EXAMPLE_CLIPS
        for suffix in suffixes
    ],
    outputs=[
        gr.Textbox(label="keyword"),
        gr.Textbox(label="Confidence Score of keyword"),
    ],
)

# Listen on all interfaces on port 7860 and request a public share link.
demo.launch(server_name='0.0.0.0', server_port=7860, share=True)