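"""Streamlit demo: music genre classification with Wav2Vec2 embeddings.

A Wav2Vec2 backbone extracts a fixed-length clip embedding; a small,
randomly initialized linear head then maps it to genre scores, so the
predictions are illustrative rather than meaningful.
"""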
import streamlit as st
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torchaudio
import io


# Initialize model and processor
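# st.cache_resource keeps the loaded model in memory across Streamlit reruns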
@st.cache_resource
def load_model():
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    return processor, model


# Audio processing function
def process_audio(audio_file, processor, model):
    # Rewind the upload buffer (it may already have been consumed, e.g. by
    # st.audio) and read the raw bytes
    audio_file.seek(0)
    audio_bytes = audio_file.read()
    # Decode the audio; MP3 support requires a torchaudio backend such as ffmpeg
    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))

    # Resample to the 16 kHz rate Wav2Vec2 was pre-trained on
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Process through Wav2Vec2
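    # The processor pads the raw waveform and packs it into a batched tensor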
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean-pool the hidden states over time into a fixed-length 768-dim embedding
    features = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
    return features


# Simple genre classifier (randomly initialized; for demonstration only)
class SimpleGenreClassifier:
    def __init__(self):
        self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
        # Simulated weights (a real application would train these);
        # 768 matches the wav2vec2-base hidden size
        self.weights = np.random.randn(768, len(self.genres))

    def predict(self, features):
        # Simple linear classification
        logits = np.dot(features, self.weights)
        probabilities = self.softmax(logits)
        return probabilities

    @staticmethod
    def softmax(x):
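        # Subtract the max before exponentiating for numerical stability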
        exp_x = np.exp(x - np.max(x))
        return exp_x / exp_x.sum()


# Page setup
st.title("🎵 Music Genre Classifier")
st.write("Upload an audio file to analyze its genre using Wav2Vec2")

# Load models
try:
    with st.spinner("Loading models..."):
        processor, wav2vec_model = load_model()
        classifier = SimpleGenreClassifier()
    st.success("Models loaded successfully!")
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()

# Create two columns for layout
col1, col2 = st.columns(2)

with col1:
    # File upload
    audio_file = st.file_uploader("Upload an audio file (MP3, WAV)", type=['mp3', 'wav'])

    if audio_file is not None:
        # Display audio player
        st.audio(audio_file)
        st.success("File uploaded successfully!")

        # Add classify button
        if st.button("Classify Genre"):
            try:
                with st.spinner("Analyzing audio..."):
                    # Extract features using Wav2Vec2
                    features = process_audio(audio_file, processor, wav2vec_model)

                    # Get genre predictions
                    probabilities = classifier.predict(features)

                    # Show results
                    st.write("### Genre Analysis Results:")
                    for genre, prob in zip(classifier.genres, probabilities):
                        # Create a progress bar for each genre
                        st.write(f"{genre}:")
                        st.progress(float(prob))
                        st.write(f"{prob:.2%}")

                    # Show top prediction
                    top_genre = classifier.genres[np.argmax(probabilities)]
                    st.write(f"**Predicted Genre:** {top_genre}")
            except Exception as e:
                st.error(f"Error during analysis: {str(e)}")

with col2:
    # Display information about the model
    st.write("### About the Model:")
    st.write("""
This classifier uses:
- Facebook's Wav2Vec2 for audio feature extraction
- A demonstration classification layer (randomly initialized, not trained)
- A backbone pre-trained on unlabeled speech audio, not on music
""")

    st.write("### Supported Genres:")
    for genre in classifier.genres:
        st.write(f"- {genre}")

    # Add usage tips
    st.write("### Tips for best results:")
    st.write("- Upload clear, high-quality audio")
    st.write("- Ideal length: 10-30 seconds")
    st.write("- Avoid audio with multiple overlapping genres")
    st.write("- Ensure minimal background noise")

# Show the pinned dependencies (requirements.txt) in the sidebar
if st.sidebar.checkbox("Show requirements.txt contents"):
    st.sidebar.code("""streamlit==1.31.0
torch==2.0.1
torchaudio==2.0.2
transformers==4.30.2
numpy==1.24.3""")

# Footer
st.markdown("---")
st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")