# genre_classify / app.py
# azeus — "adapting to audio formats" (commit eb6bacc)
import streamlit as st
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torchaudio
import io
from pydub import AudioSegment
import tempfile
import os
# Initialize model and processor
@st.cache_resource
def load_model():
    """Load the pretrained Wav2Vec2 processor and encoder.

    ``st.cache_resource`` keeps the returned objects alive across Streamlit
    reruns, so the checkpoint is loaded once and then reused.

    Returns:
        A ``(processor, model)`` tuple built from ``facebook/wav2vec2-base``.
    """
    # NOTE(review): wav2vec2-base is a pretraining checkpoint; if it ships no
    # tokenizer files, Wav2Vec2FeatureExtractor may be the better fit — confirm.
    checkpoint = "facebook/wav2vec2-base"
    loaded_processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    encoder = Wav2Vec2Model.from_pretrained(checkpoint)
    return loaded_processor, encoder
def convert_audio_to_wav(audio_file):
    """Decode an uploaded audio file and write it to a temporary WAV file.

    Args:
        audio_file: A readable binary file-like object (e.g. a Streamlit
            ``UploadedFile``). It is read from its current position to the end.

    Returns:
        The path of the temporary ``.wav`` file. The caller owns the file and
        must delete it.

    Raises:
        Whatever pydub raises when the data cannot be decoded or exported.
        In that case no temporary file is left behind.
    """
    audio_bytes = audio_file.read()
    # Decode first: if the bytes are not valid audio, nothing exists on disk yet.
    audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
    # mkstemp + close gives a unique path that is not held open. That lets
    # pydub reopen it by name on every platform; an open NamedTemporaryFile
    # cannot be reopened by name on Windows.
    fd, wav_path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        # export() returns the file object it wrote to; close it explicitly
        # instead of leaving the handle to the garbage collector.
        audio.export(wav_path, format='wav').close()
    except BaseException:
        # BaseException so the file is also removed if the script is
        # interrupted (e.g. a Streamlit rerun) mid-export.
        os.remove(wav_path)
        raise
    return wav_path
# Audio processing function
def process_audio(audio_file, processor, model):
    """Turn an uploaded audio file into a single Wav2Vec2 embedding vector.

    Pipeline:
        1. Convert the upload to WAV.
        2. Load it and resample to 16 kHz.
        3. Down-mix to mono and keep at most the first 30 seconds.
        4. Run it through ``model`` and average the last hidden states over time.

    Args:
        audio_file: Readable binary file-like object holding the uploaded audio.
        processor: Wav2Vec2 processor used to build the model inputs.
        model: Wav2Vec2 model whose ``last_hidden_state`` is pooled.

    Returns:
        A 1-D numpy array of time-averaged hidden states, or ``None`` if any
        step failed. Failures are reported to the user with ``st.error``.
    """
    target_sr = 16000
    max_length = target_sr * 30  # 30 seconds at 16 kHz
    try:
        wav_path = convert_audio_to_wav(audio_file)
        try:
            waveform, sample_rate = torchaudio.load(wav_path)
        finally:
            # Remove the temp file even when torchaudio cannot read it.
            os.remove(wav_path)

        if sample_rate != target_sr:
            resampler = torchaudio.transforms.Resample(sample_rate, target_sr)
            waveform = resampler(waveform)

        # torchaudio.load returns (channels, frames); average channels to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Keep at most the first 30 seconds; slicing a shorter clip is a no-op.
        waveform = waveform[:, :max_length]
        if waveform.shape[1] == 0:
            raise ValueError("the audio contains no samples")

        # Index the single channel instead of squeeze(), so a very short clip
        # still yields a 1-D array rather than a 0-d scalar.
        inputs = processor(
            waveform[0].numpy(),
            sampling_rate=target_sr,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            outputs = model(**inputs)
        # Average over the time axis; [0] drops the batch dimension of size 1.
        return outputs.last_hidden_state.mean(dim=1)[0].numpy()
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
# Simple genre classifier
class SimpleGenreClassifier:
    """Placeholder linear genre classifier over audio feature vectors.

    The weights are drawn from a fixed-seed random generator rather than
    trained. Predictions are therefore reproducible but carry no real genre
    information.
    """

    def __init__(self, feature_dim=768):
        """Create the classifier.

        Args:
            feature_dim: Length of the feature vectors passed to ``predict``.
                The default of 768 matches the hidden size of wav2vec2-base.
        """
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Simulated learned weights. A private RandomState seeded with 42
        # produces the same values as np.random.seed(42) + np.random.randn,
        # without resetting NumPy's global RNG for the rest of the process.
        rng = np.random.RandomState(42)
        self.weights = rng.randn(feature_dim, len(self.genres))

    def predict(self, features):
        """Return genre probabilities for ``features``.

        Args:
            features: Array whose last axis has length ``feature_dim``. A 1-D
                vector yields one distribution; a 2-D batch yields one per row.

        Returns:
            Probabilities aligned with ``self.genres`` along the last axis.
        """
        logits = np.dot(features, self.weights)
        return self.softmax(logits)

    @staticmethod
    def softmax(x):
        """Return a numerically stable softmax of ``x`` along its last axis."""
        # Subtracting the per-row max avoids overflow in exp without changing the result.
        shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(shifted)
        return exp_x / exp_x.sum(axis=-1, keepdims=True)
# Page setup: title and one-line description.
st.title("🎵 Music Genre Classifier")
st.write("Upload an audio file to analyze its genre using Wav2Vec2")

# Load the feature extractor (cached by load_model) and the placeholder
# classifier. Any failure is shown to the user and halts the script run.
try:
    with st.spinner("Loading models..."):
        processor, wav2vec_model = load_model()
        classifier = SimpleGenreClassifier()
        st.success("Models loaded successfully!")
except Exception as load_err:
    st.error(f"Error loading models: {load_err}")
    st.stop()
def _render_genre_results(probabilities):
    """Write a labelled progress bar per genre, then the top-scoring genre."""
    st.write("### Genre Analysis Results:")
    for label, score in zip(classifier.genres, probabilities):
        st.write(f"{label}:")
        st.progress(float(score))
        st.write(f"{score:.2%}")
    best_genre = classifier.genres[np.argmax(probabilities)]
    st.write(f"**Predicted Genre:** {best_genre}")


# Two-column layout: upload/classify on the left, model info on the right.
col1, col2 = st.columns(2)
with col1:
    audio_file = st.file_uploader("Upload an audio file (MP3, WAV)", type=['mp3', 'wav'])
    if audio_file is not None:
        # Let the user play back what they uploaded.
        st.audio(audio_file)
        st.success("File uploaded successfully!")
        # Rewind so the conversion step reads the upload from the start.
        audio_file.seek(0)
        if st.button("Classify Genre"):
            try:
                with st.spinner("Analyzing audio..."):
                    embedding = process_audio(audio_file, processor, wav2vec_model)
                    # process_audio reports its own errors and returns None on failure.
                    if embedding is not None:
                        _render_genre_results(classifier.predict(embedding))
            except Exception as analysis_err:
                st.error(f"Error during analysis: {analysis_err}")
with col2:
    # Static information panel about the model and usage tips.
    st.write("### About the Model:")
    st.write("""
This classifier uses:
- Facebook's Wav2Vec2 for audio feature extraction
- Custom genre classification layer
- Handles MP3 and WAV formats
""")
    st.write("### Supported Genres:")
    for genre in classifier.genres:
        st.write(f"- {genre}")
    st.write("### Tips for best results:")
    for tip in (
        "- Upload clear, high-quality audio",
        "- Best length: 10-30 seconds",
        "- Avoid audio with multiple overlapping genres",
        "- Ensure minimal background noise",
    ):
        st.write(tip)

# Footer, rendered below both columns.
st.markdown("---")
st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")