Spaces:
Sleeping
Sleeping
import os
import tempfile

import streamlit as st
from transformers import pipeline

from config import MODEL_ID
@st.cache_resource
def _load_pipeline(model_id):
    """Build the Hugging Face audio-classification pipeline for *model_id*.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps a single pipeline object per ``model_id`` and
    reuses it across reruns, instead of reloading the model every time.
    """
    return pipeline("audio-classification", model=model_id)


# Load the model and pipeline using the model_id variable; classify_audio()
# uses this module-level handle.
pipe = _load_pipeline(MODEL_ID)
def classify_audio(filepath):
    """Score an audio file against three coarse heartbeat buckets.

    Runs the module-level ``pipe`` on *filepath* and sums the returned
    scores into the keys ``"normal"``, ``"murmur"`` and ``"artifact"``
    (in that order). Underscores in each predicted label are replaced by
    spaces before matching. Any label that is not one of the three keys
    has its score added to ``"normal"``.

    Args:
        filepath: Path of the audio file handed to the pipeline.

    Returns:
        dict mapping each of the three bucket names to its summed score.
    """
    buckets = dict.fromkeys(("normal", "murmur", "artifact"), 0.0)
    for prediction in pipe(filepath):
        name = prediction["label"].replace('_', ' ')
        # NOTE(review): unrecognised labels are counted as "normal". Confirm
        # this is intended if the model emits other (possibly abnormal) classes.
        target = name if name in buckets else "normal"
        buckets[target] += prediction["score"]
    return buckets
# --- Streamlit app layout: page header -------------------------------------
st.title("Heartbeat Sound Classification")

# Sidebar picker; the selected label is stored in `theme` and decides which
# colour scheme the styling code applies.
_THEME_CHOICES = ["Light Green", "Light Blue", "Dark Green", "Dark Blue"]
theme = st.sidebar.selectbox("Select Theme", _THEME_CHOICES)
# Add custom CSS for styling based on the selected theme.
# Each palette supplies the four colours substituted into _THEME_CSS_TEMPLATE;
# the keys must match the options offered by the "Select Theme" selectbox.
_THEME_PALETTES = {
    "Light Green": {
        "background": "#e8f5e9",    # light green background
        "text": "#004d40",          # dark green text
        "accent": "#004d40",        # button / file-uploader background
        "accent_hover": "#00332c",  # darker shade on hover
    },
    "Light Blue": {
        "background": "#e0f7fa",
        "text": "#006064",
        "accent": "#006064",
        "accent_hover": "#004d40",
    },
    "Dark Green": {
        "background": "#1b5e20",
        "text": "#a5d6a7",
        "accent": "#004d40",
        "accent_hover": "#00332c",
    },
    "Dark Blue": {
        "background": "#0d47a1",
        "text": "#bbdefb",
        "accent": "#006064",
        "accent_hover": "#004d40",
    },
}

# Doubled braces are literal CSS braces for str.format().
# The background goes on .stApp rather than <body>: Streamlit's app container
# paints its own background over <body>, which would hide a body-level colour.
# .stFileUploader is the class Streamlit puts on the file-uploader element.
_THEME_CSS_TEMPLATE = """
<style>
.stApp {{
    background-color: {background};
    color: {text};
}}
.stButton > button, .stFileUploader > div {{
    background-color: {accent};
    color: white;
}}
.stButton > button:hover, .stFileUploader > div:hover {{
    background-color: {accent_hover};
}}
</style>
"""

# An unrecognised theme injects no CSS.
_palette = _THEME_PALETTES.get(theme)
if _palette is not None:
    st.markdown(_THEME_CSS_TEMPLATE.format(**_palette), unsafe_allow_html=True)
# File uploader for audio files
uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3"])
if uploaded_file is not None:
    st.subheader("Uploaded Audio File")
    # getvalue() returns the whole buffer regardless of the current stream
    # position; read() would return b"" once the buffer had been consumed.
    audio_bytes = uploaded_file.getvalue()
    # Use the MIME type reported for the upload so MP3s are not labelled as WAV.
    st.audio(audio_bytes, format=uploaded_file.type or "audio/wav")

    # Save to a unique temporary file. A fixed filename in the working
    # directory is shared by every concurrent session, so one user's upload
    # could overwrite another's before it is classified.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_bytes)
        temp_path = tmp.name
    # The handle is closed before classification so the file can be reopened
    # by path on every platform; it is always removed afterwards.
    try:
        st.write("Classifying the audio...")
        results = classify_audio(temp_path)
    finally:
        os.remove(temp_path)

    # Display the classification results, one "label: score" per line.
    st.subheader("Classification Results")
    results_str = "\n".join(f"{label}: {score:.2f}" for label, score in results.items())
    st.text(results_str)
# Audio Test Samples for classification: each button plays and classifies a
# bundled sample file, with results shown in that button's column.
st.write("Audio Test Samples:")
# NOTE(review): 'extra_hystole.wav' may be a typo; confirm the file exists.
examples = ['normal.wav', 'murmur.wav', 'extra_systole.wav', 'extra_hystole.wav', 'artifact.wav']
# Single column when the session is flagged as mobile, otherwise three.
is_mobile = st.session_state.get("is_mobile", False)
num_columns = 1 if is_mobile else 3
cols = st.columns(num_columns)
for idx, example in enumerate(examples):
    col = cols[idx % num_columns]  # rotate through the columns
    if not col.button(example):
        continue
    col.subheader(f"Sample Audio: {example}")
    # `with` closes the handle (the original open(...).read() leaked it);
    # a missing sample is reported in place instead of crashing the app.
    try:
        with open(example, 'rb') as sample_file:
            audio_bytes = sample_file.read()
    except FileNotFoundError:
        col.error(f"Sample file not found: {example}")
        continue
    col.audio(audio_bytes, format='audio/wav')
    results = classify_audio(example)
    col.write("Results:")
    results_str = "\n".join(f"{label}: {score:.2f}" for label, score in results.items())
    col.text(results_str)
# Client-side script intended to detect a mobile browser and publish the
# result as st.session_state["is_mobile"], which the sample-button layout reads.
# NOTE(review): st.markdown is not documented to execute <script> tags, and
# 'streamlit:storeSessionState' does not look like a documented Streamlit
# message type. Confirm this actually sets is_mobile; if it does not, the
# session-state lookup always falls back to its default.
_MOBILE_DETECTION_HTML = """
<script>
const isMobile = /iPhone|iPad|iPod|Android/i.test(navigator.userAgent);
window.parent.postMessage({type: 'streamlit:storeSessionState', key: 'is_mobile', value: isMobile}, '*');
</script>
"""
st.markdown(_MOBILE_DETECTION_HTML, unsafe_allow_html=True)