genre_classify / app.py
azeus
adding fb model
1e2e376
raw
history blame
4.63 kB
import streamlit as st
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torchaudio
import io
# Initialize the audio feature extractor and encoder (cached by st.cache_resource
# so they are not reloaded on every script rerun).
@st.cache_resource
def load_model(model_name="facebook/wav2vec2-base"):
    """Load the Wav2Vec2 feature extractor and encoder for ``model_name``.

    ``facebook/wav2vec2-base`` was pretrained on audio alone and ships no
    tokenizer, so a ``Wav2Vec2Processor`` (feature extractor + CTC tokenizer)
    cannot be built from it. This app only needs the audio preprocessing, so
    the feature extractor is loaded directly. It is called exactly like a
    processor: ``extractor(raw_audio, sampling_rate=..., return_tensors="pt")``.

    Args:
        model_name: Hugging Face Hub id or local path of a Wav2Vec2 checkpoint.

    Returns:
        ``(feature_extractor, model)`` tuple.
    """
    # Local import: the module header only imports Wav2Vec2Processor/Model.
    from transformers import Wav2Vec2FeatureExtractor

    processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
    model = Wav2Vec2Model.from_pretrained(model_name)
    return processor, model
# Audio processing function
def process_audio(audio_file, processor, model, target_sample_rate=16000):
    """Extract one clip-level embedding from an uploaded audio file.

    The audio is decoded with torchaudio, resampled to ``target_sample_rate``
    if necessary, down-mixed to mono, run through the Wav2Vec2 encoder, and
    mean-pooled over time.

    Args:
        audio_file: Binary file-like object (e.g. a Streamlit ``UploadedFile``).
            It is rewound first when seekable, so a stream that was already
            read elsewhere is still decoded from the beginning.
        processor: Callable turning a 1-D float waveform into model inputs
            (a Wav2Vec2 feature extractor or processor).
        model: Wav2Vec2-style model whose output exposes ``last_hidden_state``.
        target_sample_rate: Sample rate the model expects (16 kHz for Wav2Vec2).

    Returns:
        1-D NumPy array of length ``hidden_size``: the last hidden states
        averaged over time.

    Raises:
        Whatever ``torchaudio.load`` raises for empty or undecodable audio.
    """
    # Rewind so a previously consumed stream does not yield zero bytes.
    seekable = getattr(audio_file, "seekable", None)
    if seekable is not None and seekable():
        audio_file.seek(0)
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_file.read()))

    # Resample if needed
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
        waveform = resampler(waveform)

    # Down-mix to mono: (channels, time) -> (1, time)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # waveform[0] (instead of squeeze()) always gives a 1-D signal, even for a
    # single-sample clip.
    inputs = processor(
        waveform[0].numpy(),
        sampling_rate=target_sample_rate,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # last_hidden_state is (batch=1, frames, hidden): pool over frames, then
    # drop the batch axis explicitly so the result is always 1-D.
    return outputs.last_hidden_state.mean(dim=1)[0].numpy()
# Simple genre classifier (a basic linear head used for demonstration)
class SimpleGenreClassifier:
    """Linear genre classifier over pooled Wav2Vec2 embeddings.

    The weights are NOT learned: they are drawn from a standard normal
    distribution, so predictions are placeholders for demonstration only.
    A real application would load trained weights instead.

    Attributes:
        genres: Genre labels, in the same order as the weight columns.
        weights: Array of shape ``(feature_dim, len(genres))``.
    """

    def __init__(self, feature_dim=768, seed=0):
        """Create the classifier with reproducible pseudo-random weights.

        Args:
            feature_dim: Length of the feature vector given to ``predict``;
                768 matches the pooled ``facebook/wav2vec2-base`` embedding.
            seed: Seed for the weight generator. A fixed seed keeps predictions
                stable when the object is rebuilt (Streamlit re-executes the
                script, and so this constructor, on every interaction). Pass
                ``None`` for fresh random weights each time.
        """
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Simulated weights (in a real application these would be trained).
        rng = np.random.default_rng(seed)
        self.weights = rng.standard_normal((feature_dim, len(self.genres)))

    def predict(self, features):
        """Return genre probabilities for ``features``.

        Args:
            features: Array of shape ``(feature_dim,)`` for one clip, or
                ``(n, feature_dim)`` for a batch of clips.

        Returns:
            Probabilities over ``self.genres``: shape ``(len(genres),)`` for a
            single clip, ``(n, len(genres))`` for a batch; each row sums to 1.
        """
        logits = np.dot(features, self.weights)
        return self.softmax(logits)

    @staticmethod
    def softmax(x, axis=-1):
        """Numerically stable softmax of ``x`` along ``axis`` (last by default)."""
        x = np.asarray(x)
        # Subtracting the per-row max prevents overflow in exp() and does not
        # change the normalised result.
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / exp_x.sum(axis=axis, keepdims=True)
# ---- Page header -----------------------------------------------------------
st.title("🎵 Music Genre Classifier")
st.write("Upload an audio file to analyze its genre using Wav2Vec2")

# ---- Model loading ---------------------------------------------------------
# Nothing below works without the models, so any failure is reported on the
# page and the current script run is halted.
try:
    with st.spinner("Loading models..."):
        processor, wav2vec_model = load_model()
        classifier = SimpleGenreClassifier()
    st.success("Models loaded successfully!")
except Exception as load_error:
    st.error(f"Error loading models: {load_error}")
    st.stop()
# Two-column layout: upload and results on the left, model info on the right.
col1, col2 = st.columns(2)

with col1:
    audio_file = st.file_uploader(
        "Upload an audio file (MP3, WAV)",
        type=["mp3", "wav"],
    )
    if audio_file is not None:
        # Let the user play back what they uploaded.
        st.audio(audio_file)
        st.success("File uploaded successfully!")

        if st.button("Classify Genre"):
            try:
                with st.spinner("Analyzing audio..."):
                    # Wav2Vec2 embedding -> genre probability vector
                    features = process_audio(audio_file, processor, wav2vec_model)
                    probabilities = classifier.predict(features)

                st.write("### Genre Analysis Results:")
                # One labelled progress bar plus a percentage per genre.
                for label, score in zip(classifier.genres, probabilities):
                    st.write(f"{label}:")
                    st.progress(float(score))
                    st.write(f"{score:.2%}")

                best_index = int(np.argmax(probabilities))
                st.write(f"**Predicted Genre:** {classifier.genres[best_index]}")
            except Exception as analysis_error:
                st.error(f"Error during analysis: {analysis_error}")
with col2:
    # Static description of the pipeline. The backbone was pretrained
    # self-supervised on speech audio (not trained for speech recognition or
    # music), and the genre head uses random, untrained weights, so the text
    # says so instead of implying a trained classifier.
    st.write("### About the Model:")
    st.write("""
This classifier uses:
- Facebook's Wav2Vec2 for audio feature extraction
- A demo genre classification layer (untrained, randomly initialized weights)
- Wav2Vec2 was pre-trained on unlabeled speech audio, not music, so genre predictions are for demonstration only
""")

    st.write("### Supported Genres:")
    for genre in classifier.genres:
        st.write(f"- {genre}")

    # Usage tips
    st.write("### Tips for best results:")
    st.write("- Upload clear, high-quality audio")
    st.write("- Ideal length: 10-30 seconds")
    st.write("- Avoid audio with multiple overlapping genres")
    st.write("- Ensure minimal background noise")
# Sidebar: optionally show the pinned dependency versions the app targets.
# This only displays the text; no requirements.txt file is written.
_REQUIREMENTS_TXT = """
streamlit==1.31.0
torch==2.0.1
torchaudio==2.0.1
transformers==4.30.2
numpy==1.24.3
"""
if st.sidebar.checkbox("Show requirements.txt contents"):
    st.sidebar.code(_REQUIREMENTS_TXT)

# ---- Footer ----------------------------------------------------------------
st.markdown("---")
st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")