Spaces:
Sleeping
Sleeping
import os
import tempfile

import streamlit as st
from transformers import pipeline

from config import MODEL_ID
# Load the model and pipeline using the model_id variable.
# Streamlit re-executes this script on every widget interaction, so the
# pipeline is built inside a cached loader: the model is constructed once per
# server process and reused on later reruns instead of being reloaded.
# NOTE: st.cache_resource requires Streamlit 1.18 or newer.
@st.cache_resource
def _load_pipeline():
    """Build the audio-classification pipeline for MODEL_ID (cached)."""
    return pipeline("audio-classification", model=MODEL_ID)


pipe = _load_pipeline()
def classify_audio(filepath):
    """Classify an audio file and return summed scores per category.

    Runs the module-level ``pipe`` on ``filepath`` and accumulates each
    prediction's score into one of three buckets: "normal", "murmur" or
    "artifact". Underscores in a predicted label are replaced with spaces
    before matching; any label that is not one of the three bucket names has
    its score added to "normal".

    Args:
        filepath: Path of the audio file handed to the pipeline.

    Returns:
        A dict mapping "normal", "murmur" and "artifact" to the summed score.
    """
    totals = {"normal": 0.0, "murmur": 0.0, "artifact": 0.0}
    for prediction in pipe(filepath):
        name = prediction["label"].replace('_', ' ')
        # Unrecognised labels are folded into the "normal" bucket.
        bucket = name if name in totals else "normal"
        totals[bucket] += prediction["score"]
    return totals
# Streamlit app layout
st.title("Heartbeat Sound Classification")

# File uploader for audio files
uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3"])

if uploaded_file is not None:
    st.subheader("Uploaded Audio File")

    # getvalue() returns the whole payload regardless of the stream's current
    # read position, so a previously consumed buffer cannot yield empty bytes.
    audio_bytes = uploaded_file.getvalue()

    # Use the MIME type reported for the upload so MP3 files are not played
    # back as WAV; fall back to WAV when no type is available.
    st.audio(audio_bytes, format=uploaded_file.type or "audio/wav")

    # Save the upload to a unique temporary file. A fixed file name in the
    # working directory would be shared (and overwritten) by concurrent
    # sessions; keeping the original extension also stops MP3 data being
    # stored under a ".wav" name.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_bytes)
        temp_path = tmp.name

    # Classify the audio file, always removing the temporary copy afterwards.
    st.write("Classifying the audio...")
    try:
        results = classify_audio(temp_path)
    finally:
        os.remove(temp_path)

    # Display the classification results in a dedicated output box
    st.subheader("Classification Results")
    results_box = st.empty()
    results_str = "\n".join(
        f"{label}: {score:.2f}" for label, score in results.items()
    )
    results_box.text(results_str)
# Audio Test Samples for classification
st.write("Audio Test Samples:")
# Sample files are opened by these relative names, so they are resolved
# against the process's current working directory.
examples = ['normal.wav', 'murmur.wav', 'extra_systole.wav', 'extra_hystole.wav', 'artifact.wav']
cols = st.columns(3)

for idx, example in enumerate(examples):
    col = cols[idx % 3]  # Rotate columns for better arrangement
    if col.button(example):
        col.subheader(f"Sample Audio: {example}")

        # Read through a context manager so the file handle is closed
        # promptly instead of being left for the garbage collector.
        with open(example, 'rb') as sample_file:
            audio_bytes = sample_file.read()
        col.audio(audio_bytes, format='audio/wav')

        results = classify_audio(example)
        col.write("Results:")
        results_str = "\n".join(
            f"{label}: {score:.2f}" for label, score in results.items()
        )
        col.text(results_str)