# Hugging Face Spaces status banner ("Spaces: Sleeping") captured together
# with the source during export; kept here as a comment so it cannot break
# the module.
import datetime
import logging

import numpy as np
import torch
from pydub import AudioSegment
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer
# --- FFmpeg configuration -------------------------------------------------
# pydub shells out to ffmpeg/ffprobe for decoding; point it at the system
# binaries explicitly so it works even when they are not on PATH.
ffmpeg_path = "/usr/bin/ffmpeg"
ffprobe_path = "/usr/bin/ffprobe"
# Tell pydub which binaries to use for conversion and media probing.
AudioSegment.converter = ffmpeg_path
AudioSegment.ffmpeg = ffmpeg_path
AudioSegment.ffprobe = ffprobe_path
# Module-level logger; can be swapped out at runtime via set_logger().
logger = logging.getLogger(__name__)
def set_logger(new_logger):
    """Swap the module-level ``logger`` for *new_logger*.

    Lets an embedding application route this module's log output through
    its own logging configuration.
    """
    global logger
    logger = new_logger
def format_timestamp(milliseconds):
    """Convert a duration in milliseconds to an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        milliseconds: Non-negative duration in milliseconds (int or float;
            floats are truncated to whole milliseconds).

    Returns:
        The SRT-formatted timestamp string.
    """
    # Work in plain integer arithmetic rather than datetime.timedelta:
    # timedelta.seconds silently drops whole days, which wrapped timestamps
    # for material longer than 24 hours, and a float input crashed the
    # ``:03d`` format spec.
    ms = int(milliseconds)
    hours, remainder = divmod(ms // 1000, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{ms % 1000:03d}"
def generate_srt(transcriptions, chunk_length_ms, subtitle_duration_ms=5000):
    """Generate SRT content from transcribed chunks.

    Each chunk's text is split evenly across fixed-duration subtitles laid
    out back-to-back inside that chunk's time window.

    Args:
        transcriptions: List of transcribed text strings, one per audio chunk.
        chunk_length_ms: Duration of each audio chunk in milliseconds.
        subtitle_duration_ms: Target duration of each subtitle in milliseconds.

    Returns:
        The full SRT document as a single string.
    """
    srt_output = ""
    srt_index = 1
    for i, chunk_text in enumerate(transcriptions):
        chunk_start_time = i * chunk_length_ms
        chunk_end_time = (i + 1) * chunk_length_ms
        words = chunk_text.split()
        if not words:
            # Nothing to show for this chunk (e.g. silence).
            continue
        # How many subtitle slots fit in this chunk's window.
        num_subtitles = max(1, int(chunk_length_ms / subtitle_duration_ms))
        # Ceiling division so the number of word groups never exceeds
        # num_subtitles. Floor division (len(words) // num_subtitles) could
        # produce one extra group whose start time landed at or past
        # chunk_end_time, emitting a zero- or negative-length cue.
        words_per_subtitle = max(1, -(-len(words) // num_subtitles))
        for j in range(0, len(words), words_per_subtitle):
            subtitle_text = " ".join(words[j:j + words_per_subtitle])
            start_time = chunk_start_time + (j // words_per_subtitle) * subtitle_duration_ms
            end_time = min(start_time + subtitle_duration_ms, chunk_end_time)
            srt_output += f"{srt_index}\n"
            srt_output += f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}\n"
            srt_output += f"{subtitle_text}\n\n"
            srt_index += 1
    return srt_output
def transcribe_audio(model_name, audio_path, language='en', chunk_length_ms=30000):
    """Transcribe an audio file with a Whisper model, chunk by chunk.

    Args:
        model_name: Hugging Face model id or local path of the Whisper model.
        audio_path: Path to the audio file (any format ffmpeg can decode).
        language: Language code passed to the decoder prompt (default 'en').
        chunk_length_ms: Length of each processing chunk in milliseconds.

    Returns:
        Tuple ``(full_text, srt_content)``, or ``(None, None)`` on failure.
    """
    try:
        # Prefer GPU when available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        # Load model, processor and tokenizer.
        model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
        processor = WhisperProcessor.from_pretrained(model_name)
        tokenizer = WhisperTokenizer.from_pretrained(model_name)
        # Decode the audio and convert to the mono 16 kHz layout Whisper's
        # feature extractor expects. Without set_channels(1), stereo files
        # would be fed as interleaved L/R samples.
        audio = AudioSegment.from_file(audio_path)
        audio = audio.set_channels(1).set_frame_rate(16000)
        chunk_transcriptions = []
        # Process the audio in fixed-size windows.
        for i in range(0, len(audio), chunk_length_ms):
            chunk = audio[i:i + chunk_length_ms]
            chunk_array = np.array(chunk.get_array_of_samples()).astype(np.float32)
            # Peak-normalize to [-1, 1]; skip for silent chunks to avoid a
            # divide-by-zero that would fill the features with NaNs.
            peak = np.max(np.abs(chunk_array)) if chunk_array.size else 0.0
            if peak > 0:
                chunk_array = chunk_array / peak
            # Compute log-mel input features and move them to the device.
            input_features = processor(chunk_array, sampling_rate=16000, return_tensors="pt").input_features
            input_features = input_features.to(device)
            # Force language/task so the model does not have to detect them.
            forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language=language, task="transcribe")
            predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
            # Decode token ids back to text.
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            chunk_text = transcription[0].strip()
            chunk_transcriptions.append(chunk_text)
            # Surface progress in real time.
            print(f"Chunk {i // chunk_length_ms + 1} transcription:")
            print(chunk_text)
            print("-" * 50)
        # Combine the chunk texts and build the SRT document.
        full_text = " ".join(chunk_transcriptions)
        srt_content = generate_srt(chunk_transcriptions, chunk_length_ms, subtitle_duration_ms=5000)
        return full_text, srt_content
    except Exception as e:
        # logger.exception keeps the traceback, unlike logger.error.
        logger.exception(f"An error occurred during transcription: {str(e)}")
        return None, None