CoSTA / ST /inference /inference_req_hf.py
bhavanishankarpullela's picture
Upload 360 files
b817ab5 verified
import os
import numpy as np
import pandas as pd
import whisper
# import torchaudio
# import librosa
import sys
import os
# Expose only the GPU with index 2 to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from tqdm.notebook import tqdm
# from whisper.normalizers import EnglishTextNormalizer
class Decoder:
    """Thin wrapper around an openai-whisper model for ASR predictions.

    The model is loaded once in the constructor and reused by every call to
    ``decode``. The language given at construction time is the language used
    for transcription.
    """

    # Model names accepted by the constructor.
    SUPPORTED_MODELS = ('tiny', 'base', 'small', 'medium', 'large', 'large-v2', 'large-v3')

    def __init__(self, model_type, language='mr'):
        """Load the whisper model and remember the transcription language.

        More details on the whisper model and its types can be found here: https://github.com/openai/whisper
        Convert HF model to openai-whisper: https://github.com/openai/whisper/discussions/830

        Args:
            model_type (str): Should be one of 'tiny', 'base', 'small',
                'medium', 'large', 'large-v2', 'large-v3'.
            language (str): Language code handed to whisper. Defaults to
                'mr' (the script in this file writes its results to
                ``transcripts_marathi.csv``).

        Raises:
            AssertionError: If ``model_type`` is not a supported model name.
        """
        # Kept as an assert so callers that relied on AssertionError still work.
        assert model_type in self.SUPPORTED_MODELS, "Wrong model type"
        print('Info: Loading model')
        self.model = whisper.load_model(model_type)
        # Remember the language so decode() honours it instead of a fixed code.
        self.language = language
        self.decode_options = whisper.DecodingOptions(language=language, without_timestamps=True)
        # NOTE(review): hard-coded and not read anywhere in this class;
        # confirm whether external code depends on it.
        self.device = "cuda"  # if torch.cuda.is_available() else "cpu"
        print("Info: Initialization done")

    def decode(self, filepath):
        """Transcribe one audio file in the language chosen at construction.

        Args:
            filepath (str): Path of the audio file to transcribe.

        Returns:
            str: The ``"text"`` field of the whisper transcription result.
        """
        # Blank line on stdout before each transcription (kept from the
        # original output format).
        print()
        result = self.model.transcribe(
            filepath,
            language=self.language,
            verbose=False,
            without_timestamps=True,
            fp16=False,
        )
        return result["text"]
if __name__ == "__main__":
    # Exactly one command-line argument is expected: the folder that holds
    # the MP3 files. An explicit check (rather than assert) still runs
    # under `python -O`.
    if len(sys.argv) != 2:
        sys.exit("Ohh no, audio file seems to be missing")
    audio_folder_path = sys.argv[1]

    # Defined before the loop so it exists even when no MP3 file is found.
    output_dir = "./"

    # Initialize the Decoder
    obj = Decoder('large-v3', language='mr')

    # Collect one row per file and build the DataFrame once at the end;
    # DataFrame.append was removed in pandas 2.0 and copied the whole frame
    # on every call.
    rows = []

    # Iterate through all MP3 files in the folder, in sorted order so the
    # CSV row order is reproducible (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(audio_folder_path)):
        if not filename.endswith(".mp3"):
            continue
        mp3_file_path = os.path.join(audio_folder_path, filename)

        # Decode the MP3 file
        asr_output = obj.decode(mp3_file_path)
        rows.append({'MP3_File': filename, 'Transcript': asr_output})

        # Progress report every 10 transcribed files.
        if len(rows) % 10 == 0:
            print(f'{len(rows)} files done')

    # Save the file names and transcripts to a CSV file. Passing `columns`
    # keeps the header even when no MP3 file was found.
    transcripts_df = pd.DataFrame(rows, columns=['MP3_File', 'Transcript'])
    csv_save_path = os.path.join(output_dir, "transcripts_marathi.csv")
    transcripts_df.to_csv(csv_save_path, index=False)
    print("Transcription and CSV file creation completed.")