""" Inference main class. Author: Marcely Zanon Boito, 2024 """ from .CTC_model import mHubertForCTC import torch from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor from transformers import HubertConfig from datasets import load_dataset fbk_test_id = 'FBK-MT/Speech-MASSIVE-test' mhubert_id = 'utter-project/mHuBERT-147' def load_asr_model(): # Load the ASR model tokenizer = Wav2Vec2CTCTokenizer('vocab.json', unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(mhubert_id) processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) config = HubertConfig.from_pretrained('config.json') model = mHubertForCTC.from_pretrained("naver/mHuBERT-147-ASR-fr", config=config) model.eval() return model, processor def run_asr_inference(model, processor, example): audio = processor(example["array"], sampling_rate=example["sampling_rate"]).input_values[0] input_values = torch.tensor(audio).unsqueeze(0) with torch.no_grad(): logits = model(input_values).logits pred_ids = torch.argmax(logits, dim=-1) prediction = processor.batch_decode(pred_ids)[0].replace('[CTC]', "") return prediction if __name__ == '__main__': # Load the dataset in streaming mode dataset = load_dataset(fbk_test_id, 'fr-FR', streaming=True) dataset = dataset['test'] generator = iter(dataset) # load model model, processor = load_asr_model() print(model) # decode 10 examples from speech-MASSIVE num_examples= 10 while num_examples >= 0: example = next(generator) prediction = run_inference(model, processor, example['audio']) gold_standard = example['utt'] print("Gold standard:", gold_standard) print("Prediction:", prediction) print() num_examples-=1