# Hugging Face Space file view header (scraped):
# Last commit by Delik: 93d373e "Update app.py" (verified)
# File size: 2.27 kB
import os
import shutil
import tempfile

import gradio as gr
import spaces
import torch
from pyannote.audio import Pipeline
# Instantiate the pyannote speaker-diarization pipeline once, at import time.
# Any failure is printed to stdout and leaves `pipeline` set to None; the
# request handlers check for that and return "Error: Pipeline not initialized"
# instead of crashing the app.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    # Hugging Face Hub access token used to download the model.
    _hf_token = os.environ.get("api")
    if not _hf_token:
        raise RuntimeError(
            "environment variable 'api' (Hugging Face access token) is not set"
        )
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=_hf_token,
    )
    # from_pretrained may return None (rather than raise) when the model
    # cannot be downloaded; report that clearly instead of letting it surface
    # as an AttributeError from `None.to(...)`.
    if pipeline is None:
        raise RuntimeError(
            "could not load 'pyannote/speaker-diarization-3.1'; check that the "
            "'api' token can access this model"
        )
    # Move the pipeline to the GPU when one is available.
    pipeline.to(device)
except Exception as e:
    print(f"Error initializing pipeline: {e}")
    pipeline = None
def save_audio(audio):
    """Copy an uploaded audio file to a new, uniquely named temporary file.

    Args:
        audio: Path to the uploaded file, as supplied by
            ``gr.Audio(type="filepath")``; may be ``None`` when nothing was
            uploaded.

    Returns:
        The path of the temporary copy (the caller should delete it when
        done), or an ``"Error: ..."`` message string if the pipeline is not
        loaded or no audio was provided.

    Raises:
        OSError: If the upload cannot be read or the copy cannot be written.
    """
    if pipeline is None:
        return "Error: Pipeline not initialized"
    if audio is None:
        return "Error: No audio file provided"
    # Use a fresh file per request so overlapping requests cannot overwrite
    # (and later delete) each other's audio, as a single shared file name
    # would allow. Keep the upload's extension so the name matches its format.
    suffix = os.path.splitext(audio)[1] or ".wav"
    fd, temp_path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    try:
        shutil.copyfile(audio, temp_path)
    except OSError:
        # Don't leave an empty temporary file behind if the copy fails.
        os.remove(temp_path)
        raise
    return temp_path
@spaces.GPU
def diarize_audio(temp_file, num_speakers, min_speakers, max_speakers):
    """Diarize an audio file and return the result as text.

    The file at ``temp_file`` is deleted afterwards, whether diarization
    succeeded or failed.

    Args:
        temp_file: Path to the audio file to process (normally the value
            returned by ``save_audio``). A string starting with ``"Error:"``
            is treated as an upstream error message and returned unchanged.
        num_speakers: Exact number of speakers; ``0`` or ``None`` omits it.
        min_speakers: Minimum number of speakers; ``0`` or ``None`` omits it.
        max_speakers: Maximum number of speakers; ``0`` or ``None`` omits it.

    Returns:
        ``str()`` of the pipeline output on success, otherwise an error
        message string.
    """
    if pipeline is None:
        return "Error: Pipeline not initialized"
    if isinstance(temp_file, str) and temp_file.startswith("Error:"):
        return temp_file
    try:
        params = {}
        for name, value in (
            ("num_speakers", num_speakers),
            ("min_speakers", min_speakers),
            ("max_speakers", max_speakers),
        ):
            # gr.Number yields floats (e.g. 2.0) by default and may yield None
            # for a cleared field; forward only positive values, as ints.
            if value is not None and value > 0:
                params[name] = int(value)
        diarization = pipeline(temp_file, **params)
    except Exception as e:
        return f"Error processing audio: {e}"
    finally:
        # Clean up on failure as well as success, so errored requests do not
        # leave temporary audio files behind.
        try:
            os.remove(temp_file)
        except (OSError, TypeError):
            # Already removed, or not a usable path; nothing to clean up.
            pass
    return str(diarization)
def process_audio(audio, num_speakers, min_speakers, max_speakers):
    """Gradio click handler: stage the upload with save_audio, then diarize it."""
    return diarize_audio(save_audio(audio), num_speakers, min_speakers, max_speakers)


with gr.Blocks() as demo:
    audio_input = gr.Audio(type="filepath", label="Upload Audio")
    # precision=0 makes Gradio round the entered value and return an int, so
    # speaker counts are integers rather than floats such as 2.0.
    # A value of 0 leaves that constraint unset (see diarize_audio).
    num_speakers_input = gr.Number(
        label="Number of Speakers", value=0, precision=0
    )
    min_speakers_input = gr.Number(
        label="Minimum Number of Speakers", value=0, precision=0
    )
    max_speakers_input = gr.Number(
        label="Maximum Number of Speakers", value=0, precision=0
    )
    process_button = gr.Button("Process")
    diarization_output = gr.Textbox(label="Diarization Output")
    process_button.click(
        fn=process_audio,
        inputs=[audio_input, num_speakers_input, min_speakers_input, max_speakers_input],
        outputs=diarization_output,
    )
demo.launch()