import streamlit as st
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torchaudio
import io
from pydub import AudioSegment
import tempfile
import os
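
# How to run (assumption: this file is saved as app.py; adjust the name to your
# own app file). MP3 decoding goes through pydub, which needs an ffmpeg binary
# available on the system PATH:
#   pip install streamlit numpy torch torchaudio transformers pydub
#   streamlit run app.py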

# Initialize model and processor (cached so the weights are only downloaded and
# loaded once per session instead of on every rerun)
@st.cache_resource
def load_model():
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    return processor, model

def convert_audio_to_wav(audio_file):
    """Convert uploaded audio to WAV format"""
    # Read uploaded file
    audio_bytes = audio_file.read()
    # Create a temporary file
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_wav:
        # Convert audio using pydub
        audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
        audio.export(temp_wav.name, format='wav')
        return temp_wav.name

# Audio processing function
def process_audio(audio_file, processor, model):
    try:
        # Convert audio to WAV format
        wav_path = convert_audio_to_wav(audio_file)
        # Load the WAV file
        waveform, sample_rate = torchaudio.load(wav_path)
        # Clean up temporary file
        os.remove(wav_path)
        # Resample if needed
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        # Limit audio length to 30 seconds
        max_length = 16000 * 30  # 30 seconds at 16kHz
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]
        # Process through Wav2Vec2
        inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        # Get features from last hidden states
        features = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
        return features
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
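
# Note: for "facebook/wav2vec2-base" the last hidden state has 768 dimensions,
# so mean-pooling over time yields a single 768-dimensional feature vector.
# That is the shape the (768 x num_genres) weight matrix below expects.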

# Simple genre classifier
class SimpleGenreClassifier:
    def __init__(self):
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Simulated learned weights
        np.random.seed(42)  # For consistent results
        self.weights = np.random.randn(768, len(self.genres))

    def predict(self, features):
        logits = np.dot(features, self.weights)
        probabilities = self.softmax(logits)
        return probabilities

    def softmax(self, x):
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()
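
# The weights above are simulated (random). One way to swap in real weights is
# a sketch like the following, assuming a (768 x 5) matrix trained offline on
# Wav2Vec2 features and saved to the hypothetical file "genre_weights.npy":
#
#   classifier = SimpleGenreClassifier()
#   classifier.weights = np.load("genre_weights.npy")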

# Page setup
st.title("🎵 Music Genre Classifier")
st.write("Upload an audio file to analyze its genre using Wav2Vec2")

# Load models
try:
    with st.spinner("Loading models..."):
        processor, wav2vec_model = load_model()
        classifier = SimpleGenreClassifier()
    st.success("Models loaded successfully!")
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()

# Create two columns
col1, col2 = st.columns(2)

with col1:
    # File upload
    audio_file = st.file_uploader("Upload an audio file (MP3, WAV)", type=['mp3', 'wav'])
    if audio_file is not None:
        # Display audio player
        st.audio(audio_file)
        st.success("File uploaded successfully!")
        # Reset file pointer
        audio_file.seek(0)
        # Add classify button
        if st.button("Classify Genre"):
            try:
                with st.spinner("Analyzing audio..."):
                    # Extract features
                    features = process_audio(audio_file, processor, wav2vec_model)
                    if features is not None:
                        # Get predictions
                        probabilities = classifier.predict(features)
                        # Show results
                        st.write("### Genre Analysis Results:")
                        for genre, prob in zip(classifier.genres, probabilities):
                            st.write(f"{genre}:")
                            st.progress(float(prob))
                            st.write(f"{prob:.2%}")
                        # Show top prediction
                        top_genre = classifier.genres[np.argmax(probabilities)]
                        st.write(f"**Predicted Genre:** {top_genre}")
            except Exception as e:
                st.error(f"Error during analysis: {str(e)}")

with col2:
    st.write("### About the Model:")
    st.write("""
    This classifier uses:
    - Facebook's Wav2Vec2 for audio feature extraction
    - A custom genre classification layer
    - Support for MP3 and WAV input formats
    """)
st.write("### Supported Genres:") | |
for genre in classifier.genres: | |
st.write(f"- {genre}") | |
st.write("### Tips for best results:") | |
st.write("- Upload clear, high-quality audio") | |
st.write("- Best length: 10-30 seconds") | |
st.write("- Avoid audio with multiple overlapping genres") | |
st.write("- Ensure minimal background noise") | |

# Footer
st.markdown("---")
st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")