Spaces:
Sleeping
Sleeping
import gradio as gr | |
import time | |
import logging | |
import torch | |
from sys import platform | |
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor | |
from transformers.utils import is_flash_attn_2_available | |
from languages import get_language_names | |
from subtitle_manager import Subtitle | |
from pytube import YouTube | |
import os | |
# Emit INFO-level records so the diagnostic logging.info calls in this module are visible.
logging.basicConfig(level=logging.INFO)

# Process-wide cache of the most recently built ASR pipeline.
# last_model: model id that `pipe` was created from (None before the first load).
# pipe: the cached pipeline object (None before the first load).
last_model, pipe = None, None
def write_file(output_file, subtitle):
    """Save the text *subtitle* to the path *output_file*.

    The file is opened in text mode with UTF-8 encoding and truncated first,
    so any previous content is replaced. Returns None.
    """
    with open(output_file, mode="w", encoding="utf-8") as handle:
        handle.write(subtitle)
def create_pipe(model, flash):
    """Build a transformers automatic-speech-recognition pipeline for *model*.

    Args:
        model: model id or path passed to ``from_pretrained`` (the UI offers
            Hugging Face Hub ids such as ``"openai/whisper-tiny"``).
        flash: when truthy AND ``is_flash_attn_2_available()`` is True, the
            model is loaded with ``attn_implementation="flash_attention_2"``;
            otherwise ``"sdpa"`` is used.

    Returns:
        The pipeline, placed on ``cuda:0`` if CUDA is available, else on
        ``mps`` if running on macOS with an available MPS backend, else on CPU.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
    elif platform == "darwin" and torch.backends.mps.is_available():
        # Being on macOS does not guarantee an MPS device (e.g. Intel Macs or
        # torch builds without MPS), so check the backend explicitly instead
        # of failing later in `.to("mps")`.
        device = "mps"
    else:
        device = "cpu"
    # Half precision only on CUDA; MPS and CPU keep float32.
    torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

    use_flash = bool(flash) and is_flash_attn_2_available()
    if flash and not use_flash:
        logging.info("Flash Attention 2 requested but not available; falling back to sdpa")

    model_id = model
    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if use_flash else "sdpa",
    )
    asr_model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
def download_youtube_audio(url, filename="temp_audio"):
    """Download the first audio-only stream of the YouTube video at *url*.

    Args:
        url: video URL, passed to ``pytube.YouTube``.
        filename: file name given to ``Stream.download``. The default is a
            fixed name, so every call writes to the same file; concurrent
            requests would overwrite each other's download.

    Returns:
        The path returned by ``Stream.download``.

    Raises:
        ValueError: if the video exposes no audio-only stream.
    """
    yt = YouTube(url)
    audio_stream = yt.streams.filter(only_audio=True).first()
    if audio_stream is None:
        # `.first()` yields None on an empty query; report that clearly
        # instead of crashing with AttributeError on `None.download`.
        raise ValueError(f"No audio-only stream found for {url}")
    return audio_stream.download(filename=filename)
# Flash Attention flag the cached global `pipe` was built with. Together with
# `last_model` it decides whether the pipeline has to be rebuilt.
_last_flash = None


def _ensure_pipe(modelName, flash, progress):
    """Make the global `pipe` a pipeline for *modelName* built with *flash*.

    Reuses the cached pipeline when neither the model id nor the flash
    setting changed; otherwise drops the old pipeline and calls `create_pipe`.
    """
    global last_model
    global pipe
    global _last_flash
    if pipe is not None and modelName == last_model and flash == _last_flash:
        logging.info("Model not changed")
        return
    if last_model is None:
        logging.info("first model")
    else:
        logging.info("new model")
        # Release the global reference to the old pipeline first;
        # empty_cache() cannot free memory still held by a live model.
        pipe = None
        torch.cuda.empty_cache()
    progress(0.1, desc="Loading Model..")
    pipe = create_pipe(modelName, flash)
    last_model = modelName
    _last_flash = flash


def _collect_audio_files(urlData, multipleFiles, microphoneData, progress):
    """Return the audio inputs to transcribe: uploads, then the URL download, then the recording.

    Raises:
        gr.Error: if downloading the URL's audio fails.
    """
    files = list(multipleFiles) if multipleFiles else []
    if urlData:
        progress(0.2, desc="Downloading YouTube Audio..")
        try:
            files.append(download_youtube_audio(urlData))
        except Exception as e:
            logging.error(f"Failed to download YouTube audio: {e}")
            # gr.Error is shown to the user; the old `["Error"]` return value
            # was not a real file and could not be served by the File output.
            raise gr.Error(f"Failed to download YouTube audio: {e}") from e
    if microphoneData:
        files.append(microphoneData)
    return files


def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Gradio handler: transcribe every provided audio input and write subtitles.

    Args:
        modelName: model id chosen in the "Model" dropdown. Ids ending in
            ".en" receive neither `language` nor `task` generate kwargs.
        languageName: "Automatic Detection" or a language name; forwarded to
            generation as `language` unless automatic.
        urlData: video URL whose audio is downloaded (may be empty).
        multipleFiles: uploaded files; used as paths (presumably a list of
            str from gr.File with file_count="multiple" — TODO confirm).
        microphoneData: file path from gr.Audio(type="filepath"), or empty.
        task: "transcribe" or "translate".
        flash: request Flash Attention 2 when building the pipeline.
        chunk_length_s: chunk length forwarded to the pipeline.
        batch_size: batch size forwarded to the pipeline (cast to int).
        progress: Gradio progress tracker.

    Returns:
        (paths of the .srt/.vtt/.txt files written to the current working
        directory for every input, VTT text of the last input, TXT text of
        the last input).

    Raises:
        gr.Error: if no audio input was provided or the URL download failed.
    """
    progress(0, desc="Loading Audio..")
    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")

    # Fail fast (before loading a model); with no input the loop below would
    # never run and the return would hit unbound `vtt`/`txt`.
    if not (multipleFiles or urlData or microphoneData):
        raise gr.Error("No audio provided: upload a file, record audio or enter a URL.")

    _ensure_pipe(modelName, flash, progress)

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    files = _collect_audio_files(urlData, multipleFiles, microphoneData, progress)
    logging.info(files)

    generate_kwargs = {}
    if not modelName.endswith(".en"):
        if languageName != "Automatic Detection":
            generate_kwargs["language"] = languageName
        generate_kwargs["task"] = task

    # gr.Number may deliver a float (e.g. 24.0); batching expects an int.
    if batch_size is not None:
        batch_size = int(batch_size)

    files_out = []
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        logging.info(f"transcribe: {time.time() - start_time} sec.")
        # Output files go to the current working directory, named after the input.
        file_out = os.path.basename(file)
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_out + ".srt", srt)
        write_file(file_out + ".vtt", vtt)
        write_file(file_out + ".txt", txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]
    progress(1, desc="Completed!")
    # Only the last input's VTT/TXT text is shown in the text outputs.
    return files_out, vtt, txt
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"

    # Model ids offered in the "Model" dropdown.
    whisper_models = [
        "openai/whisper-tiny",
        "openai/whisper-tiny.en",
        "openai/whisper-base",
        "openai/whisper-base.en",
        "openai/whisper-small",
        "openai/whisper-small.en",
        "distil-whisper/distil-small.en",
        "openai/whisper-medium",
        "openai/whisper-medium.en",
        "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v1",
        "openai/whisper-large-v2",
        "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3",
        "distil-whisper/distil-large-v3",
        "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]

    # Appearance of the waveform shown by the audio input component.
    waveform_options = gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )

    # Input components are listed in the same order as the handler's
    # positional parameters; the three outputs map to its return tuple.
    simple_transcribe = gr.Interface(
        fn=transcribe_webui_simple_progress,
        inputs=[
            gr.Dropdown(
                choices=whisper_models,
                value="distil-whisper/distil-large-v2",
                label="Model",
                info="Select whisper model",
                interactive=True,
            ),
            gr.Dropdown(
                choices=["Automatic Detection", *sorted(get_language_names())],
                value="Automatic Detection",
                label="Language",
                info="Select audio voice language",
                interactive=True,
            ),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive=True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Input",
                waveform_options=waveform_options,
            ),
            gr.Dropdown(
                choices=["transcribe", "translate"],
                label="Task",
                value="transcribe",
                interactive=True,
            ),
            gr.Checkbox(label="Flash", info="Use Flash Attention 2"),
            gr.Number(label="chunk_length_s", value=30, interactive=True),
            gr.Number(label="batch_size", value=24, interactive=True),
        ],
        outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments"),
        ],
        description=description,
    )

if __name__ == "__main__":
    demo.launch()