# legen / whisperx_utils.py
# Uploaded by RafaG: "Upload 24 files" (commit 5fa5566, verified)
import os
from pathlib import Path
import whisperx
import whisper # only for detect language
import whisper_utils
import subtitle_utils
from utils import time_task
def _describe_progress_state(state) -> str:
    """Return a printable label for a progress ``state`` value.

    Objects exposing a non-None ``.value`` (e.g. Enum members) are labelled by
    that value, plain strings and ints are used as text, and anything else
    (including ``None``) falls back to the generic label "WhisperX".
    """
    value = getattr(state, "value", None)
    if value is not None:
        return str(value)
    if isinstance(state, (str, int)):
        return str(state)
    return "WhisperX"


def _progress_callback(state=None, current: int = None, total: int = None):
    """Print a one-line progress indicator that overwrites itself via ``\\r``.

    Accepts ``(state)``, ``(current, total)`` or ``(state, current, total)``:
    ``None`` arguments are dropped and the remaining positional values are
    re-interpreted, with the last two always read as ``current`` and ``total``.
    """
    args = [arg for arg in (state, current, total) if arg is not None]
    state = current = total = None
    if len(args) == 1:
        state = args[0]
    elif len(args) >= 2:
        *head, current, total = args
        state = head[0] if head else None

    label = _describe_progress_state(state)
    if current and total:
        label += f": {round(current / total * 100)}% [{current}/{total}]"
    print("\r \r" + label, end=" ", flush=True)


def _align_transcript(segments, audio, lang: str, device: str):
    """Align ``segments`` to ``audio`` using the default alignment model for ``lang``.

    Alignment is attempted on ``device`` first. If that fails and ``device`` is
    not already the CPU, it is retried once on the CPU (GPU alignment has been
    seen to error out in some setups). A failure on the CPU is re-raised.
    """
    try:
        model_a, metadata = whisperx.load_align_model(language_code=lang, device=device)
        return whisperx.align(transcript=segments, model=model_a, align_model_metadata=metadata, audio=audio, device=device, return_char_alignments=True, on_progress=_progress_callback)
    except Exception as exc:
        if device == "cpu":
            raise
        print(f"\nAlignment on {device} failed ({exc}); retrying on CPU")
    model_a, metadata = whisperx.load_align_model(language_code=lang, device="cpu")
    return whisperx.align(transcript=segments, model=model_a, align_model_metadata=metadata, audio=audio, device="cpu", return_char_alignments=True, on_progress=_progress_callback)


def transcribe_audio(model: whisperx.asr.WhisperModel, audio_path: Path, srt_path: Path, lang: str = None, device: str = "cpu", batch_size: int = 4):
    """Transcribe ``audio_path`` with WhisperX, align it when possible, and save an SRT.

    Args:
        model: loaded WhisperX model; its ``model.feature_extractor.sampling_rate``
            is used to resample the audio.
        audio_path: audio file to transcribe.
        srt_path: destination subtitle file.
        lang: language code, or ``None`` to let the engine decide.
        device: device used for the alignment model (falls back to "cpu" on error).
        batch_size: batch size forwarded to ``model.transcribe``.

    Returns:
        The transcription result (the alignment result when alignment ran);
        its ``"segments"`` entry is what gets written to ``srt_path``.
    """
    audio = whisperx.load_audio(file=audio_path.as_posix(), sr=model.model.feature_extractor.sampling_rate)

    with time_task("Running WhisperX transcription engine...", end='\n'):
        transcribe = model.transcribe(audio=audio, language=lang, batch_size=batch_size, on_progress=_progress_callback)

    # Without an explicit language, use the one the engine reports (if the
    # result carries a "language" key) so auto-detected audio can still be aligned.
    align_lang = lang or transcribe.get("language")
    if align_lang in whisperx.alignment.DEFAULT_ALIGN_MODELS_HF or align_lang in whisperx.alignment.DEFAULT_ALIGN_MODELS_TORCH:
        with time_task(message_start="Running alignment...", end='\n'):
            transcribe = _align_transcript(transcribe["segments"], audio, align_lang, device)
    else:
        print(f"Language {align_lang} not supported for alignment. Skipping this step")

    segments = subtitle_utils.format_segments(transcribe['segments'])
    subtitle_utils.SaveSegmentsToSrt(segments, srt_path)
    return transcribe
def detect_language(model: whisperx.asr.WhisperModel, audio_path: Path):
    """Return the language code detected in ``audio_path``.

    Fast path: run the WhisperX model's own encoder and language head on the
    audio padded/trimmed to ``feature_extractor.n_samples`` samples. The fast
    path is skipped on Google Colab (detected through the ``COLAB_RELEASE_TAG``
    environment variable). If it is skipped or raises any ``Exception``,
    detection falls back to openai-whisper's ``base`` model on the CPU via
    ``whisper_utils.detect_language``.
    """
    if not os.getenv("COLAB_RELEASE_TAG"):  # fast path considered invalid on Colab
        try:
            feature_extractor = model.model.feature_extractor
            audio = whisperx.load_audio(audio_path.as_posix(), feature_extractor.sampling_rate)
            audio = whisper.pad_or_trim(audio, feature_extractor.n_samples)
            mel = whisperx.asr.log_mel_spectrogram(audio, n_mels=model.model.model.n_mels)
            encoder_output = model.model.encode(mel)
            results = model.model.model.detect_language(encoder_output)
            language_token, _language_probability = results[0][0]
            # Presumably a special token such as "<|en|>"; drop the two
            # marker characters on each side to get the bare code.
            return language_token[2:-2]
        except Exception:
            # Deliberate best effort: any failure above falls through to the
            # slower whisper-based detection below. Exception (not a bare
            # except) so Ctrl+C / SystemExit still propagate.
            pass

    print("using whisper base model for detection: ", end='')
    whisper_model = whisper.load_model("base", device="cpu", in_memory=True)
    return whisper_utils.detect_language(model=whisper_model, audio_path=audio_path)