litagin's picture
Fix design
d50db7b
raw
history blame
4.25 kB
import os
import time
import warnings
from pathlib import Path
import gradio as gr
import huggingface_hub
import librosa
import spaces
import torch
from loguru import logger
from transformers import pipeline
# Silence all Python warnings process-wide (presumably to hide noisy
# library warnings in the demo logs — note this also hides useful ones).
warnings.filterwarnings("ignore")

# Log in to the Hugging Face Hub only when a token is configured.
# huggingface_hub.login(token=None) falls back to an interactive prompt
# outside notebooks, which blocks or fails on a headless server. Without a
# token, Hub requests use any locally cached credentials, or none.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    huggingface_hub.login(token=_hf_token)
else:
    logger.warning("HF_TOKEN is not set; skipping Hugging Face Hub login")

# True when the SYSTEM environment variable equals "spaces" (presumably set
# by the Hugging Face Spaces runtime — confirm).
is_hf = os.getenv("SYSTEM") == "spaces"
# Decoding options passed to every pipeline call:
#   - force Japanese transcription,
#   - greedy decoding (no sampling, a single beam),
#   - no n-gram repetition constraint (size 0 disables it),
#   - cap each result at 64 newly generated tokens.
generate_kwargs = dict(
    language="Japanese",
    do_sample=False,
    num_beams=1,
    no_repeat_ngram_size=0,
    max_new_tokens=64,
)
# Display label -> Hugging Face Hub repository id of the checkpoint to load.
# Insertion order is preserved, so pipelines are built in this order.
_MODEL_ENTRIES = [
    ("whisper-large-v2", "openai/whisper-large-v2"),
    ("whisper-large-v3", "openai/whisper-large-v3"),
    ("kotoba-whisper-v2.0", "kotoba-tech/kotoba-whisper-v2.0"),
    ("anime-whisper", "litagin/anime-whisper"),
]
model_dict = dict(_MODEL_ENTRIES)
logger.info("Initializing pipelines...")
# Pick the device once: the answer is the same for every model, so there is
# no reason to re-query CUDA inside the comprehension.
_device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Device: {_device}")
# Eagerly build one ASR pipeline per entry in model_dict, keyed by the same
# label. All checkpoints are loaded up front, so startup time and memory grow
# with the number of entries.
pipe_dict = {
    label: pipeline(
        "automatic-speech-recognition",
        model=repo_id,
        device=_device,
    )
    for label, repo_id in model_dict.items()
}
logger.success("Pipelines initialized!")
@spaces.GPU
def _transcribe_on_gpu(y, sr: int, model: str) -> str:
    """Run the selected ASR pipeline on an already-decoded waveform.

    Only this step holds the GPU allocation requested by ``spaces.GPU``;
    decoding and length validation happen beforehand on the CPU.

    Args:
        y: Mono waveform as returned by ``librosa.load``.
        sr: Sampling rate of ``y`` in Hz.
        model: Key into ``pipe_dict`` selecting the pipeline.

    Returns:
        The transcribed text.
    """
    start_time = time.perf_counter()  # monotonic clock for elapsed time
    # Pass the sampling rate explicitly rather than relying on the pipeline
    # to assume the waveform already matches its feature extractor.
    result = pipe_dict[model](
        {"raw": y, "sampling_rate": sr},
        generate_kwargs=generate_kwargs,
    )["text"]
    elapsed = time.perf_counter() - start_time
    logger.success(f"Finished in {elapsed:.2f}s\n{result}")
    return result


def transcribe_common(audio: str, model: str, max_duration: float = 15.0) -> str:
    """Transcribe an audio file with one of the pipelines in ``pipe_dict``.

    Args:
        audio: Path to the audio file; falsy when no audio was provided.
        model: Key into ``pipe_dict`` selecting the ASR pipeline.
        max_duration: Longest clip, in seconds, that will be transcribed.

    Returns:
        The transcription, or a human-readable message when no audio was
        given or the clip is longer than ``max_duration``.
    """
    if not audio:
        return "No audio file"
    filename = Path(audio).name
    logger.info(f"Model: {model}")
    logger.info(f"Audio: {filename}")
    # Read and resample to 16 kHz mono before requesting a GPU, so rejecting
    # an over-long clip never costs GPU time.
    y, sr = librosa.load(audio, mono=True, sr=16000)
    duration = librosa.get_duration(y=y, sr=sr)
    logger.info(f"Duration: {duration:.2f}s")
    if duration > max_duration:
        # ":g" renders the default 15.0 as "15", keeping the message unchanged.
        message = f"Audio too long, limit is {max_duration:g} seconds, got {duration:.2f}s"
        logger.error(message)
        return message
    return _transcribe_on_gpu(y, sr, model)
def transcribe_others(audio) -> tuple[str, str, str]:
    """Transcribe *audio* with each of the three comparison models.

    Returns:
        The Whisper-large-v2, Whisper-large-v3 and Kotoba-Whisper-v2.0
        results, in that order.
    """
    comparison_models = ("whisper-large-v2", "whisper-large-v3", "kotoba-whisper-v2.0")
    # Models run one after another, in the order listed above.
    v2_text, v3_text, kotoba_text = (
        transcribe_common(audio, name) for name in comparison_models
    )
    return v2_text, v3_text, kotoba_text
def transcribe_anime_whisper(audio) -> str:
    """Transcribe *audio* with the fine-tuned Anime-Whisper model."""
    target_model = "anime-whisper"
    return transcribe_common(audio, target_model)
# Markdown shown at the top of the demo page.
# The blank line before "pipeに渡している..." ends the bullet list; without it,
# CommonMark treats that sentence as a lazy continuation of the last bullet.
# NOTE: the kwargs snippet below mirrors generate_kwargs by hand — keep the
# two in sync when decoding options change.
initial_md = """
# Anime-Whisper Demo
- 音声認識モデル [kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0) をファインチューンしたモデルのお試し
- https://huggingface.co/litagin/anime-whisper
- デモでは**音声は15秒まで**しか受け付けません
- 日本語のみ対応 (Japanese only)
- 比較のために [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) と [openai/whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) と [kotoba-tech/kotoba-whisper-v2.0](https://huggingface.co/kotoba-tech/kotoba-whisper-v2.0) も用意しています

pipeに渡しているkwargsは以下の最低限のもの:
```python
generate_kwargs = {
"language": "Japanese",
"do_sample": False,
"num_beams": 1,
"no_repeat_ngram_size": 0,
"max_new_tokens": 64, # 結果が長いときは途中で打ち切る
}
```
"""
# Build the demo UI: a shared audio input, an Anime-Whisper section, and a
# comparison section whose button runs the three baseline models.
# NOTE(review): the nesting below (Comparison heading and button placed
# outside the first Row) is inferred — confirm it matches the intended layout.
with gr.Blocks() as app:
    gr.Markdown(initial_md)
    # type="filepath" makes Gradio hand the handlers a path string, which
    # transcribe_common passes to librosa.load.
    audio = gr.Audio(type="filepath")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Anime-Whisper")
            button_galgame = gr.Button("Transcribe with Anime-Whisper")
            output_galgame = gr.Textbox(label="Result")
    gr.Markdown("### Comparison")
    button_others = gr.Button("Transcribe with other models")
    # One column per comparison model, laid out side by side in a Row.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Whisper-Large-V2")
            output_v2 = gr.Textbox(label="Result")
        with gr.Column():
            gr.Markdown("### Whisper-Large-V3")
            output_v3 = gr.Textbox(label="Result")
        with gr.Column():
            gr.Markdown("### Kotoba-Whisper-V2.0")
            output_kotoba_v2 = gr.Textbox(label="Result")
    button_galgame.click(
        transcribe_anime_whisper,
        inputs=[audio],
        outputs=[output_galgame],
    )
    # Output order must match the (v2, v3, kotoba) tuple returned by
    # transcribe_others.
    button_others.click(
        transcribe_others,
        inputs=[audio],
        outputs=[output_v2, output_v3, output_kotoba_v2],
    )
# Launch only when run as a script, so importing this module (e.g. for
# tooling or tests) does not start a server.
if __name__ == "__main__":
    # Open a local browser tab only when not running on Spaces (is_hf), where
    # there is presumably no desktop browser to open.
    app.launch(inbrowser=not is_hf)