# Residue from the Hugging Face Space page this file was copied from (not part of the program):
# Space status: Sleeping. File size: 4,632 bytes.
# Commit hashes listed in the page gutter: 97751f5, c00ec95, 1e2e376, b6a2312.
# Line-number gutter: 1-144.
import streamlit as st
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torchaudio
import io
# Initialize model and processor
@st.cache_resource
def load_model():
    """Load the pretrained Wav2Vec2 processor and encoder.

    Both objects come from the same Hugging Face checkpoint. The function is
    wrapped in ``st.cache_resource`` so Streamlit can reuse the loaded
    objects instead of rebuilding them (see Streamlit's caching docs).

    Returns:
        A ``(processor, model)`` tuple, in that order.
    """
    checkpoint = "facebook/wav2vec2-base"
    return (
        Wav2Vec2Processor.from_pretrained(checkpoint),
        Wav2Vec2Model.from_pretrained(checkpoint),
    )
# Audio processing function
def process_audio(audio_file, processor, model):
    """Turn an uploaded audio file into a single pooled Wav2Vec2 feature vector.

    The audio is decoded with torchaudio, down-mixed to one channel,
    resampled to 16 kHz, passed through ``model`` and averaged over
    dimension 1 of ``last_hidden_state``.

    Args:
        audio_file: Binary file-like object exposing ``read()`` (the app
            passes the value returned by ``st.file_uploader``). If it also
            exposes ``seek()``, it is rewound before reading.
        processor: Callable invoked as ``processor(samples,
            sampling_rate=16000, return_tensors="pt", padding=True)``; its
            result is unpacked as keyword arguments into ``model``.
        model: Callable whose result has a ``last_hidden_state`` tensor.

    Returns:
        NumPy array holding the mean of ``last_hidden_state`` over
        dimension 1, with singleton dimensions squeezed away.

    Raises:
        ValueError: If the file is empty or decodes to zero samples.
    """
    target_rate = 16000

    # The caller may already have consumed the stream (the app hands the same
    # object to st.audio first), so rewind when the object supports it.
    rewind = getattr(audio_file, "seek", None)
    if callable(rewind):
        rewind(0)

    # Read audio file
    audio_bytes = audio_file.read()
    if not audio_bytes:
        raise ValueError("The uploaded audio file is empty.")

    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
    if waveform.numel() == 0:
        raise ValueError("The uploaded audio file contains no samples.")

    # Convert to mono first so the resampler only has one channel to process.
    # Channel averaging and resampling are both linear, so the order should
    # only affect float rounding.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample if needed
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
        waveform = resampler(waveform)

    # squeeze(0) drops only the channel axis; a bare squeeze() would also
    # collapse a one-sample clip into a 0-d array.
    samples = waveform.squeeze(0).numpy()

    # Process through Wav2Vec2
    inputs = processor(
        samples, sampling_rate=target_rate, return_tensors="pt", padding=True
    )
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)

    # Get features from last hidden states
    features = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
    return features
# Simple genre classifier (we'll use a basic classifier for demonstration)
class SimpleGenreClassifier:
    """Demonstration-only linear genre classifier over pooled audio features.

    The weight matrix is random, not trained, so the predictions carry no
    real meaning; the class only shows where a trained layer would plug in.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, feature_dim=768, seed=0):
        """Create the classifier with reproducible placeholder weights.

        Args:
            feature_dim: Length of the feature vectors passed to
                ``predict``. The default of 768 is the size the app uses
                with facebook/wav2vec2-base features (confirm against the
                checkpoint's config if the encoder changes).
            seed: Seed for the weight generator. A fixed default makes every
                instance score the same features identically; pass ``None``
                for fresh random weights on each construction.
        """
        # Genre labels, in the same order as the columns of ``weights``.
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Simulated learned weights (in real application, these would be trained)
        rng = np.random.default_rng(seed)
        self.weights = rng.standard_normal((feature_dim, len(self.genres)))

    def predict(self, features):
        """Return genre probabilities for one feature vector or a batch.

        Args:
            features: Array-like whose last dimension equals the number of
                rows in ``self.weights``.

        Returns:
            Array of probabilities over ``self.genres``; the last axis sums
            to 1.

        Raises:
            ValueError: If the last dimension of ``features`` does not match
                the weight matrix.
        """
        features = np.asarray(features)
        expected = self.weights.shape[0]
        if features.ndim == 0 or features.shape[-1] != expected:
            raise ValueError(
                f"Expected features with last dimension {expected}, "
                f"got shape {features.shape}"
            )
        # Simple linear classification
        logits = np.dot(features, self.weights)
        return self.softmax(logits)

    @staticmethod
    def softmax(x):
        """Numerically stable softmax along the last axis of ``x``."""
        x = np.asarray(x)
        # A 0-d input has no last axis; reduce over everything instead.
        axis = -1 if x.ndim else None
        # Subtracting the maximum keeps np.exp from overflowing.
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / exp_x.sum(axis=axis, keepdims=True)
# Page setup
st.title("🎵 Music Genre Classifier")
st.write("Upload an audio file to analyze its genre using Wav2Vec2")

# Load models
# NOTE(review): the indentation of this section was reconstructed from the
# statement order; confirm the nesting against the deployed app.
try:
    with st.spinner("Loading models..."):
        processor, wav2vec_model = load_model()
        # Streamlit re-executes this script on widget interaction, and
        # SimpleGenreClassifier draws its weights at construction time, so a
        # fresh instance per run could score the same upload differently on
        # every click. Keep one instance per session instead.
        if "classifier" not in st.session_state:
            st.session_state["classifier"] = SimpleGenreClassifier()
        classifier = st.session_state["classifier"]
        st.success("Models loaded successfully!")
except Exception as e:  # top-level UI boundary: report the failure and halt
    st.error(f"Error loading models: {str(e)}")
    st.stop()
# Create two columns for layout
col1, col2 = st.columns(2)

# Left column: upload widget, audio preview and the classification results.
# NOTE(review): the nesting below was reconstructed from the statement order
# (dedenting only where the syntax forces it); confirm against the deployed app.
with col1:
    # File upload
    audio_file = st.file_uploader(
        "Upload an audio file (MP3, WAV)",
        type=['mp3', 'wav'],
    )

    if audio_file is not None:
        # Display audio player
        st.audio(audio_file)
        st.success("File uploaded successfully!")

        # Add classify button
        classify_clicked = st.button("Classify Genre")
        if classify_clicked:
            try:
                with st.spinner("Analyzing audio..."):
                    # Extract features using Wav2Vec2
                    features = process_audio(audio_file, processor, wav2vec_model)

                    # Get genre predictions
                    probabilities = classifier.predict(features)

                    # Show results: one label, progress bar and percentage per genre
                    st.write("### Genre Analysis Results:")
                    for label, score in zip(classifier.genres, probabilities):
                        st.write(f"{label}:")
                        st.progress(float(score))
                        st.write(f"{score:.2%}")

                    # Show top prediction
                    best_index = np.argmax(probabilities)
                    top_genre = classifier.genres[best_index]
                    st.write(f"**Predicted Genre:** {top_genre}")
            except Exception as e:  # UI boundary: show the error instead of crashing
                st.error(f"Error during analysis: {str(e)}")
# Right column: static description of the model, the genre list and usage tips.
with col2:
    # Display information about the model
    st.write("### About the Model:")
    # Adjacent literals reproduce the original multi-line text exactly,
    # including its leading and trailing newline.
    st.write(
        "\n"
        "This classifier uses:\n"
        "- Facebook's Wav2Vec2 for audio feature extraction\n"
        "- Custom genre classification layer\n"
        "- Pre-trained on speech recognition\n"
    )

    st.write("### Supported Genres:")
    for genre_name in classifier.genres:
        st.write(f"- {genre_name}")

    # Add usage tips
    st.write("### Tips for best results:")
    usage_tips = (
        "- Upload clear, high-quality audio",
        "- Ideal length: 10-30 seconds",
        "- Avoid audio with multiple overlapping genres",
        "- Ensure minimal background noise",
    )
    for tip in usage_tips:
        st.write(tip)
# Show the pinned dependency list in the sidebar when the box is ticked.
# This only displays text; it does not write or update any requirements file.
show_requirements = st.sidebar.checkbox("Show requirements.txt contents")
if show_requirements:
    pinned_packages = (
        "streamlit==1.31.0",
        "torch==2.0.1",
        "torchaudio==2.0.1",
        "transformers==4.30.2",
        "numpy==1.24.3",
    )
    # Leading and trailing newline match the original triple-quoted text.
    st.sidebar.code("\n" + "\n".join(pinned_packages) + "\n")
# Footer: horizontal rule followed by the attribution line.
st.markdown("---")
st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")