# Spaces:
# Sleeping
# Sleeping
import io

import numpy as np
import streamlit as st
import torch
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model, Wav2Vec2Processor
# Initialize the Wav2Vec2 feature extractor and encoder (cached across reruns)
@st.cache_resource(show_spinner=False)
def load_model(model_name="facebook/wav2vec2-base"):
    """Load the Wav2Vec2 input pre-processor and encoder for ``model_name``.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded objects alive across reruns, so
    the weights are fetched and built once instead of on every click. The
    spinner is disabled here because the caller already shows one.

    Args:
        model_name: Hugging Face Hub id (or local path) of a Wav2Vec2
            checkpoint. Defaults to the original ``facebook/wav2vec2-base``.

    Returns:
        tuple: ``(processor, model)``. ``processor`` is a callable that turns
        a raw waveform into model inputs. ``model`` is the ``Wav2Vec2Model``
        encoder.
    """
    # The base checkpoint was pre-trained on audio alone and (per its model
    # card) ships no tokenizer. A full Wav2Vec2Processor also tries to load a
    # tokenizer, so load only the feature extractor, which is all that is
    # needed to prepare encoder inputs.
    processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
    model = Wav2Vec2Model.from_pretrained(model_name)
    return processor, model
# Audio processing function
def process_audio(audio_file, processor, model):
    """Extract one fixed-size Wav2Vec2 embedding from an uploaded audio file.

    Steps:
    1. Decode the file with torchaudio.
    2. Down-mix to mono.
    3. Resample to 16 kHz.
    4. Run the encoder without gradients.
    5. Mean-pool the last hidden state over time.

    Args:
        audio_file: Binary file-like object (e.g. Streamlit's UploadedFile)
            holding audio that torchaudio can decode.
        processor: Callable that turns a 1-D waveform into model inputs
            (e.g. the first value returned by ``load_model``).
        model: Wav2Vec2 encoder whose output exposes ``last_hidden_state``.

    Returns:
        numpy.ndarray: 1-D embedding of length ``hidden_size`` (768 for
        wav2vec2-base).

    Raises:
        ValueError: If the decoded audio contains no samples.
    """
    target_rate = 16000

    # Rewind first: the same file object may already have been read from
    # elsewhere on the page (e.g. by an audio player widget).
    if hasattr(audio_file, "seek"):
        audio_file.seek(0)
    audio_bytes = audio_file.read()
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    if waveform.numel() == 0:
        raise ValueError("The uploaded audio file contains no samples.")

    # Dim 0 is the channel axis. Down-mix *before* resampling so only a
    # single channel has to be resampled.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
        waveform = resampler(waveform)

    # waveform is now (1, num_samples). Take the single channel explicitly
    # rather than squeeze(), which would turn a 1-sample clip into a 0-d array.
    inputs = processor(
        waveform[0].numpy(),
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Average over the time axis, then drop the batch dimension of size 1.
    features = outputs.last_hidden_state.mean(dim=1)[0].numpy()
    return features
# Simple genre classifier (demonstration only: a linear layer with placeholder weights)
class SimpleGenreClassifier:
    """Toy linear genre classifier over Wav2Vec2 embeddings.

    The weight matrix is random and untrained, so its predictions only
    illustrate the pipeline and carry no real genre information.

    Args:
        feature_dim: Length of the embedding vectors given to ``predict``.
            Defaults to 768, the original hard-coded value.
        seed: Seed for the placeholder weights. A fixed seed means repeated
            constructions (e.g. on every Streamlit rerun) give the same
            predictions for the same input. Pass ``None`` for fresh random
            weights on each construction.
    """

    def __init__(self, feature_dim=768, seed=0):
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Simulated learned weights (in a real application these would be trained).
        rng = np.random.default_rng(seed)
        self.weights = rng.standard_normal((feature_dim, len(self.genres)))

    def predict(self, features):
        """Return genre probabilities for ``features``.

        Args:
            features: Array of shape ``(feature_dim,)`` or
                ``(n, feature_dim)``.

        Returns:
            numpy.ndarray: Shape ``(len(genres),)`` or ``(n, len(genres))``.
            Each row sums to 1, and columns follow the order of ``self.genres``.
        """
        logits = np.dot(features, self.weights)
        return self.softmax(logits)

    @staticmethod
    def softmax(x):
        """Numerically stable softmax along the last axis of ``x``."""
        # Subtracting the row max avoids overflow in exp without changing the
        # result. Previously this method lacked ``self``/``@staticmethod``, so
        # ``self.softmax(...)`` raised TypeError.
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / exp_x.sum(axis=-1, keepdims=True)
# Page header shown at the top of the app.
st.title("🎵 Music Genre Classifier")
st.write("Upload an audio file to analyze its genre using Wav2Vec2")

# Build the encoder and the genre classifier. Nothing below works without
# them, so on failure report the error and halt this script run.
try:
    with st.spinner("Loading models..."):
        processor, wav2vec_model = load_model()
        classifier = SimpleGenreClassifier()
    st.success("Models loaded successfully!")
except Exception as load_error:
    st.error(f"Error loading models: {load_error}")
    st.stop()
# Two-column layout: upload/classify on the left, model info on the right.
col1, col2 = st.columns(2)

with col1:
    audio_file = st.file_uploader("Upload an audio file (MP3, WAV)", type=['mp3', 'wav'])

    if audio_file is not None:
        # Let the user play back what they uploaded.
        st.audio(audio_file)
        st.success("File uploaded successfully!")

        # Analysis runs only on the rerun triggered by pressing the button.
        if st.button("Classify Genre"):
            try:
                with st.spinner("Analyzing audio..."):
                    embedding = process_audio(audio_file, processor, wav2vec_model)
                    genre_probs = classifier.predict(embedding)

                st.write("### Genre Analysis Results:")
                for genre_name, genre_prob in zip(classifier.genres, genre_probs):
                    # One label, one bar and one percentage per genre.
                    st.write(f"{genre_name}:")
                    st.progress(float(genre_prob))
                    st.write(f"{genre_prob:.2%}")

                best_index = np.argmax(genre_probs)
                st.write(f"**Predicted Genre:** {classifier.genres[best_index]}")
            except Exception as analysis_error:
                st.error(f"Error during analysis: {analysis_error}")
with col2:
    # Static description of the pipeline, shown next to the uploader.
    st.write("### About the Model:")
    # The classification layer uses random placeholder weights (see
    # SimpleGenreClassifier). Say so instead of implying a trained model.
    st.write(
        "This classifier uses:\n"
        "- Facebook's Wav2Vec2 for audio feature extraction\n"
        "- A demo genre classification layer with random, untrained weights "
        "(predictions are illustrative only)\n"
        "- A Wav2Vec2 encoder pre-trained on speech audio, not music\n"
    )

    st.write("### Supported Genres:")
    for genre in classifier.genres:
        st.write(f"- {genre}")

    # Usage tips for the uploader.
    st.write("### Tips for best results:")
    st.write("- Upload clear, high-quality audio")
    st.write("- Ideal length: 10-30 seconds")
    st.write("- Avoid audio with multiple overlapping genres")
    st.write("- Ensure minimal background noise")
# Pinned dependency list, displayed read-only in the sidebar when the user
# ticks the checkbox. Nothing is written to disk.
_REQUIREMENTS_TXT = """
streamlit==1.31.0
torch==2.0.1
torchaudio==2.0.1
transformers==4.30.2
numpy==1.24.3
"""

show_requirements = st.sidebar.checkbox("Show requirements.txt contents")
if show_requirements:
    st.sidebar.code(_REQUIREMENTS_TXT)

# Page footer: horizontal rule plus credits line.
st.markdown("---")
st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")