mms-zeroshot / zeroshot.py
vineelpratap's picture
Update zeroshot.py
56eee40 verified
raw
history blame
7.61 kB
import json
import os
import re
import subprocess
import tempfile

import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
from transformers import AutoProcessor, Wav2Vec2ForCTC

from utils.lm import create_unigram_lm, maybe_generate_pseudo_bigram_arpa
from utils.text_norm import text_normalize
# uroman checkout, resolved relative to the current working directory; its
# Perl entry point is what uromanize() executes.
uroman_dir = "uroman"
UROMAN_PL = os.path.join(uroman_dir, "bin", "uroman.pl")
# Fail fast with a clear message. An `assert` would be stripped under -O and
# would only check the directory, deferring the failure to the first uroman call.
if not os.path.isfile(UROMAN_PL):
    raise FileNotFoundError(
        f"uroman script not found at '{UROMAN_PL}'; expected a uroman checkout "
        f"in '{uroman_dir}' relative to the working directory"
    )
ASR_SAMPLING_RATE = 16_000
# Default beam-search word-insertion scores (with / without an LM) and LM weight.
WORD_SCORE_DEFAULT_IF_LM = -0.18
WORD_SCORE_DEFAULT_IF_NOLM = -3.5
LM_SCORE_DEFAULT = 1.48
MODEL_ID = "upload/mms_zs"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
# Token inventory handed to the CTC beam-search decoder.
token_file = "upload/mms_zs/tokens.txt"
class MY_LOG:
    """Accumulate progress messages into one growing string.

    The buffer starts as ``"[START]"``. Each :meth:`add` call appends a message
    and returns the whole buffer, so callers can yield it directly.
    """

    def __init__(self):
        self.text = "[START]"

    def add(self, new_log, new_line=True):
        """Append *new_log* and return the updated buffer.

        The message goes on its own line when *new_line* is true; otherwise it
        is joined to the previous text with a single space. The whole buffer is
        stripped of surrounding whitespace after every append.
        """
        separator = "\n" if new_line else " "
        updated = "".join((self.text, separator, new_log))
        self.text = updated.strip()
        return self.text
def error_check_file(filepath):
    """Check that *filepath* is a string naming an existing path.

    This function reports problems instead of raising them.

    Returns:
        ``None`` when the check passes, otherwise a human-readable error message.
    """
    if not isinstance(filepath, str):
        return "Expected file to be of type 'str'. Instead got {}".format(
            type(filepath)
        )
    if not os.path.exists(filepath):
        # Report the offending path itself; formatting type(filepath) here
        # always printed "<class 'str'>".
        return "Input file '{}' doesn't exist".format(filepath)
    return None
def norm_uroman(text):
    """Reduce romanized text to lowercase ``a-z``, apostrophes and single spaces.

    Curly apostrophes become ASCII ``'``. Every other character outside
    ``[a-z' ]`` (digits, punctuation, newlines, leftover non-ASCII) becomes a
    space. Runs of spaces are then collapsed and the ends trimmed.
    """
    cleaned = text.lower().replace("’", "'")
    cleaned = re.sub(r"[^a-z' ]", " ", cleaned)
    return re.sub(r" +", " ", cleaned).strip()
def uromanize(words):
    """Romanize *words* with uroman and build a character-level CTC lexicon.

    Args:
        words: sequence of words without embedding newlines. Words are sent to
            uroman one per line, and output line ``i`` is mapped back to
            ``words[i]``.

    Returns:
        dict mapping each word to its spelling. The spelling is the normalized
        romanization's characters separated by spaces and terminated by the
        ``|`` token, e.g. ``"abc" -> "a b c |"``. Words whose romanization line
        is blank are omitted.

    Raises:
        subprocess.CalledProcessError: uroman exited with a non-zero status.
        OSError: ``perl`` could not be started.
    """
    iso = "xxx"
    # Pipe data directly (no shell, no temp files) and fix the encoding to
    # UTF-8 so non-Latin words do not depend on the locale. check=True turns
    # a uroman failure into an exception instead of a silently empty lexicon.
    result = subprocess.run(
        ["perl", UROMAN_PL, "-l", iso],
        input="\n".join(words),
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    )
    lexicon = {}
    for idx, line in enumerate(result.stdout.split("\n")):
        if not line.strip():
            continue
        romanized = re.sub(r"\s+", "", norm_uroman(line))
        lexicon[words[idx]] = " ".join(romanized) + " |"
    return lexicon
def filter_lexicon(lexicon, word_counts):
    """Keep one word per distinct spelling.

    When several words share a spelling, the survivor is the one with the
    highest count in *word_counts*. Ties go to the shorter word, i.e. the one
    with the fewest extra characters, and then to the earliest one. A spelling
    with a single word keeps that word without consulting *word_counts*.

    Returns:
        A new ``{word: spelling}`` dict; the input is not modified.
    """
    words_by_spelling = {}
    for word, spelling in lexicon.items():
        words_by_spelling.setdefault(spelling, []).append(word)

    filtered = {}
    for spelling, candidates in words_by_spelling.items():
        chosen = candidates[0]
        if len(candidates) > 1:
            chosen = min(candidates, key=lambda w: (-word_counts[w], len(w)))
        filtered[chosen] = spelling
    return filtered
def load_words(filepath):
    """Count normalized word occurrences in a text file.

    Args:
        filepath: path to a UTF-8 text file. Each line counts as one sentence.

    Returns:
        ``(word_counts, num_sentences)``. ``word_counts`` maps each
        whitespace-separated token of the ``text_normalize``-d text to its
        frequency. ``num_sentences`` is the number of lines in the file, with
        blank lines included.
    """
    # Word lists can be in any script, so decode as UTF-8 explicitly rather
    # than relying on the platform's default encoding.
    with open(filepath, encoding="utf-8") as f:
        lines = f.readlines()
    num_sentences = len(lines)
    # Normalize the whole file in a single call, with lines joined by spaces.
    all_sentences = " ".join(line.strip() for line in lines)
    words = {}
    for word in text_normalize(all_sentences).split():
        words[word] = words.get(word, 0) + 1
    return words, num_sentences
def _load_audio(audio_data):
    """Return float samples at ``ASR_SAMPLING_RATE`` from a microphone tuple or a file path."""
    if isinstance(audio_data, tuple):
        # microphone: (sample_rate, samples). The 32768 scaling presumably
        # assumes int16 PCM samples — TODO confirm against the caller.
        sr, audio_samples = audio_data
        audio_samples = (audio_samples / 32768.0).astype(float)
        if sr != ASR_SAMPLING_RATE:
            audio_samples = librosa.resample(
                audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
            )
        return audio_samples
    # file upload: a path that librosa decodes, downmixes and resamples.
    if not isinstance(audio_data, str):
        raise TypeError(
            f"audio_data must be a (sr, samples) tuple or a file path, got {type(audio_data)}"
        )
    return librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]


def _select_device():
    """Pick CUDA if available, else Apple MPS if available and built, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if (
        hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
        and torch.backends.mps.is_built()
    ):
        return torch.device("mps")
    return torch.device("cpu")


def process(
    audio_data,
    words_file,
    lm_path=None,
    wscore=None,
    lmscore=None,
    wscore_usedefault=True,
    lmscore_usedefault=True,
    autolm=True,
    reference=None,
):
    """Transcribe audio with the zero-shot CTC model, constrained to a word list.

    This is a generator. Each step yields ``(transcription, log_text)`` so the
    caller can stream progress, and the last yield carries the decoded
    transcription. On a handled failure it yields ``("ERROR: ...", log_text)``
    and stops.

    Args:
        audio_data: ``(sample_rate, samples)`` tuple (microphone) or a path to
            an audio file (upload).
        words_file: path to a text file of words or sentences that defines the
            decoding vocabulary.
        lm_path: optional path to a language model for the decoder, presumably
            an ARPA/KenLM file (TODO confirm). A blank string is treated as
            ``None``. When ``autolm`` builds a unigram LM, that LM replaces
            this path.
        wscore: word score, used only when ``wscore_usedefault`` is false.
        lmscore: LM weight, used only when ``lmscore_usedefault`` is false.
        wscore_usedefault: pick ``WORD_SCORE_DEFAULT_IF_LM`` or
            ``WORD_SCORE_DEFAULT_IF_NOLM`` depending on whether an LM is used.
        lmscore_usedefault: use ``LM_SCORE_DEFAULT`` with an LM, 0 without.
        autolm: build a unigram LM from ``words_file`` when some word occurs
            more than twice.
        reference: not used by this function.
    """
    transcription, logs = "", MY_LOG()
    if not audio_data or not words_file:
        yield "ERROR: Empty audio data or words file", logs.text
        return
    # A blank lm_path means "no LM". Normalize it up front. Previously this
    # happened only after the LM branch below, so a blank path skipped lexicon
    # filtering and was handed to maybe_generate_pseudo_bigram_arpa().
    if lm_path is not None and not lm_path.strip():
        lm_path = None

    audio_samples = _load_audio(audio_data)
    yield transcription, logs.add(f"Number of audio samples: {len(audio_samples)}")
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)
    yield transcription, logs.add(f"Using device: {device}")
    with torch.no_grad():
        outputs = model(**inputs).logits

    # Setup lexicon and decoder
    yield transcription, logs.add("Loading words....")
    try:
        word_counts, num_sentences = load_words(words_file)
    except Exception as e:
        yield f"ERROR: Loading words failed '{str(e)}'", logs.text
        return
    yield transcription, logs.add(
        f"Loaded {len(word_counts)} words from {num_sentences} lines.\nPreparing lexicon...."
    )
    try:
        lexicon = uromanize(list(word_counts.keys()))
    except Exception as e:
        yield f"ERROR: Creating lexicon failed '{str(e)}'", logs.text
        return
    yield transcription, logs.add(f"Leixcon size: {len(lexicon)}")

    # tmp_file backs an auto-generated LM, if one is built. The `with` deletes
    # it when the generator finishes or is closed early, instead of leaving
    # cleanup to garbage collection.
    with tempfile.NamedTemporaryFile() as tmp_file:
        # Input could be sentences OR a list of words. Treat it as sentences
        # (and build an LM) only if at least one word occurs more than twice.
        if autolm and any(cnt > 2 for cnt in word_counts.values()):
            yield transcription, logs.add("Creating unigram LM...", False)
            lm_path = tmp_file.name
            create_unigram_lm(word_counts, num_sentences, lm_path)
            yield transcription, logs.add("OK")
        if lm_path is None:
            yield transcription, logs.add("Filtering lexicon....")
            lexicon = filter_lexicon(lexicon, word_counts)
            yield transcription, logs.add(
                f"Ok. Leixcon size after filtering: {len(lexicon)}"
            )
        else:
            # kenlm throws an error if unigram LM is being used
            # HACK: generate a bigram LM from unigram LM and a dummy bigram to trick it
            maybe_generate_pseudo_bigram_arpa(lm_path)

        with tempfile.NamedTemporaryFile() as lexicon_file:
            # One "<word> <spelling tokens>" entry per line. Words may be
            # non-ASCII, so the file is written as UTF-8 explicitly.
            with open(lexicon_file.name, "w", encoding="utf-8") as f:
                for word, spelling in lexicon.items():
                    f.write(word + " " + spelling + "\n")
            if wscore_usedefault:
                wscore = (
                    WORD_SCORE_DEFAULT_IF_LM
                    if lm_path is not None
                    else WORD_SCORE_DEFAULT_IF_NOLM
                )
            if lmscore_usedefault:
                lmscore = LM_SCORE_DEFAULT if lm_path is not None else 0
            yield transcription, logs.add(
                f"Using word score: {wscore}\nUsing lm score: {lmscore}"
            )
            beam_search_decoder = ctc_decoder(
                lexicon=lexicon_file.name,
                tokens=token_file,
                lm=lm_path,
                nbest=1,
                beam_size=500,
                beam_size_token=50,
                lm_weight=lmscore,
                word_score=wscore,
                sil_score=0,
                blank_token="<s>",
            )
            # Decode while both temporary files still exist.
            beam_search_result = beam_search_decoder(outputs.to("cpu"))
    transcription = " ".join(beam_search_result[0][0].words).strip()
    yield transcription, logs.add("[DONE]")
# for i in process("upload/english/english.mp3", "upload/english/c4_5k_sentences.txt"):
# print(i)
# for i in process("upload/ligurian/ligurian_1.mp3", "upload/ligurian/zenamt_5k_sentences.txt"):
# print(i)