import gc
import logging
import os
import time
from sys import platform

import gradio as gr
import torch
from pytube import YouTube
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available

from languages import get_language_names
from subtitle_manager import Subtitle

logging.basicConfig(level=logging.INFO)

# Currently loaded ASR pipeline plus the settings it was built with; a request
# with the same model and flash setting reuses it instead of reloading.
last_model = None
last_flash = None
pipe = None


def write_file(output_file, subtitle):
    """Write ``subtitle`` to ``output_file`` as UTF-8, overwriting any existing file."""
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)


def create_pipe(model, flash):
    """Build a transformers "automatic-speech-recognition" pipeline.

    Args:
        model: Hugging Face model id, e.g. "openai/whisper-tiny".
        flash: request Flash Attention 2. It is used only when
            ``is_flash_attn_2_available()`` is true; otherwise "sdpa" is used.

    Returns:
        The pipeline, placed on CUDA if available, else Apple MPS if available,
        else CPU.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
    elif platform == "darwin" and torch.backends.mps.is_available():
        # Not every macOS machine / torch build has MPS (e.g. Intel Macs).
        device = "mps"
    else:
        device = "cpu"

    # Half precision only on CUDA; MPS and CPU run in float32.
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model_id = model
    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    )
    asr_model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


def download_youtube_audio(url):
    """Download the first audio-only stream of the video at ``url``.

    The stream is saved under the file name ``temp_audio``.

    Returns:
        The path of the downloaded file, as reported by pytube.

    Raises:
        ValueError: if the video exposes no audio-only stream.
    """
    yt = YouTube(url)
    audio_stream = yt.streams.filter(only_audio=True).first()
    if audio_stream is None:
        raise ValueError(f"No audio-only stream found for {url}")
    return audio_stream.download(filename='temp_audio')


def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Gradio callback: transcribe every provided audio source and write subtitles.

    Audio is taken from uploaded files, a YouTube URL and/or the microphone
    component. For each input file, ``<basename>.srt``, ``.vtt`` and ``.txt``
    are written to the current working directory.

    Args:
        modelName: Hugging Face model id; ids ending in ".en" get no
            language/task options.
        languageName: language passed to generation, or "Automatic Detection".
        urlData: YouTube URL (may be empty).
        multipleFiles: list of uploaded file paths (may be None/empty).
        microphoneData: file path from the audio component (may be None).
        task: "transcribe" or "translate".
        flash: whether to request Flash Attention 2.
        chunk_length_s: chunk length in seconds for the pipeline.
        batch_size: pipeline batch size; cast to int because gr.Number may
            deliver a float.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        (list of written subtitle paths, VTT text of the last file,
        TXT text of the last file).

    Raises:
        gr.Error: if no audio source was given or the YouTube download failed.
    """
    global last_model, last_flash, pipe

    progress(0, desc="Loading Audio..")

    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")

    # Fail fast (and before loading a model) when there is nothing to transcribe;
    # otherwise the return below would reference unassigned vtt/txt.
    if not (multipleFiles or urlData or microphoneData):
        raise gr.Error("No audio provided: enter a URL, upload files, or use the audio input.")

    if pipe is None or modelName != last_model or flash != last_flash:
        if pipe is None:
            logging.info("first model")
        else:
            logging.info("new model")
            # Drop the old pipeline before emptying the cache; while it is still
            # referenced, empty_cache() cannot release its GPU memory.
            pipe = None
            gc.collect()
            torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
        # Record the settings only after a successful load so a failed load is retried.
        last_model = modelName
        last_flash = flash
    else:
        logging.info("Model not changed")

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    files = []
    if multipleFiles:
        files += multipleFiles

    if urlData:
        progress(0.2, desc="Downloading YouTube Audio..")
        try:
            url_audio_file = download_youtube_audio(urlData)
        except Exception as e:
            logging.exception(f"Failed to download YouTube audio: {e}")
            raise gr.Error(f"Failed to download YouTube audio: {e}") from e
        files.append(url_audio_file)

    if microphoneData:
        files.append(microphoneData)
    logging.info(files)

    generate_kwargs = {}
    if not modelName.endswith(".en"):
        # English-only (*.en) checkpoints are not given language/task options.
        if languageName != "Automatic Detection":
            generate_kwargs["language"] = languageName
        generate_kwargs["task"] = task

    if batch_size is not None:
        batch_size = int(batch_size)

    files_out = []
    vtt = txt = ""
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.perf_counter()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        logging.info(f"transcribe: {time.perf_counter() - start_time} sec.")

        # basename works for both POSIX and Windows separators.
        file_out = os.path.basename(file)
        chunks = outputs["chunks"]
        srt = srt_sub.get_subtitle(chunks)
        vtt = vtt_sub.get_subtitle(chunks)
        txt = txt_sub.get_subtitle(chunks)

        for extension, content in ((".srt", srt), (".vtt", vtt), (".txt", txt)):
            write_file(file_out + extension, content)
            files_out.append(file_out + extension)

    progress(1, desc="Completed!")
    # NOTE(review): the UI labels these outputs "Transcription" (gets VTT text) and
    # "Segments" (gets TXT text); only the last file's text is shown — confirm intended.
    return files_out, vtt, txt


with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"

    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3",
        "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]

    waveform_options = gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )

    simple_transcribe = gr.Interface(
        fn=transcribe_webui_simple_progress,
        description=description,
        inputs=[
            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model",
                        info="Select whisper model", interactive=True),
            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection",
                        label="Language", info="Select audio voice language", interactive=True),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive=True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input",
                     waveform_options=waveform_options),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive=True),
            gr.Checkbox(label='Flash', info='Use Flash Attention 2'),
            gr.Number(label='chunk_length_s', value=30, interactive=True),
            # precision=0 makes Gradio deliver an int, which the pipeline's batching requires.
            gr.Number(label='batch_size', value=24, precision=0, interactive=True),
        ],
        outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments"),
        ],
    )

if __name__ == "__main__":
    demo.launch()