azeus committed
Commit eb6bacc · 1 Parent(s): f3617b6

adapting to audio formats

Files changed (1):
  1. app.py +71 -48
app.py CHANGED
@@ -4,6 +4,9 @@ import torch
 from transformers import Wav2Vec2Processor, Wav2Vec2Model
 import torchaudio
 import io
+from pydub import AudioSegment
+import tempfile
+import os
 
 
 # Initialize model and processor
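
Editor's aside on the new imports, not part of the diff: pydub reads WAV and raw audio natively but shells out to ffmpeg for MP3 and other compressed formats, so the conversion path added below only works if an ffmpeg binary is available in the Space. A quick sanity check using pydub's bundled helper:

from pydub.utils import which

# pydub looks up the ffmpeg/avconv binary on PATH; None means
# AudioSegment.from_file will fail on MP3 and other compressed formats.
if which("ffmpeg") is None:
    print("ffmpeg not found - install it (e.g. via packages.txt on Spaces)")
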
@@ -14,40 +17,68 @@ def load_model():
     return processor, model
 
 
+def convert_audio_to_wav(audio_file):
+    """Convert uploaded audio to WAV format"""
+    # Read uploaded file
+    audio_bytes = audio_file.read()
+
+    # Create a temporary file
+    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_wav:
+        # Convert audio using pydub
+        audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
+        audio.export(temp_wav.name, format='wav')
+        return temp_wav.name
+
+
 # Audio processing function
 def process_audio(audio_file, processor, model):
-    # Read audio file
-    audio_bytes = audio_file.read()
-    waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
+    try:
+        # Convert audio to WAV format
+        wav_path = convert_audio_to_wav(audio_file)
+
+        # Load the WAV file
+        waveform, sample_rate = torchaudio.load(wav_path)
+
+        # Clean up temporary file
+        os.remove(wav_path)
 
-    # Resample if needed
-    if sample_rate != 16000:
-        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
-        waveform = resampler(waveform)
+        # Resample if needed
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+            waveform = resampler(waveform)
 
-    # Convert to mono if stereo
-    if waveform.shape[0] > 1:
-        waveform = torch.mean(waveform, dim=0, keepdim=True)
+        # Convert to mono if stereo
+        if waveform.shape[0] > 1:
+            waveform = torch.mean(waveform, dim=0, keepdim=True)
 
+        # Limit audio length to 30 seconds
+        max_length = 16000 * 30  # 30 seconds at 16kHz
+        if waveform.shape[1] > max_length:
+            waveform = waveform[:, :max_length]
 
-    # Process through Wav2Vec2
-    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
-    with torch.no_grad():
-        outputs = model(**inputs)
+        # Process through Wav2Vec2
+        inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
+        with torch.no_grad():
+            outputs = model(**inputs)
 
-    # Get features from last hidden states
-    features = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
-    return features
+        # Get features from last hidden states
+        features = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
+        return features
 
+    except Exception as e:
+        st.error(f"Error processing audio: {str(e)}")
+        return None
+
 
-# Simple genre classifier (we'll use a basic classifier for demonstration)
+# Simple genre classifier
 class SimpleGenreClassifier:
     def __init__(self):
         self.genres = ["Rock", "Pop", "Hip Hop", "Classical", "Jazz"]
-        # Simulated learned weights (in real application, these would be trained)
+        # Simulated learned weights
+        np.random.seed(42)  # For consistent results
         self.weights = np.random.randn(768, len(self.genres))
 
     def predict(self, features):
-        # Simple linear classification
        logits = np.dot(features, self.weights)
         probabilities = self.softmax(logits)
         return probabilities
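
One caveat on the hunk above: convert_audio_to_wav creates its temp file with delete=False and leaves deletion to process_audio, so the WAV lingers on disk whenever torchaudio.load raises before os.remove runs. A minimal leak-free sketch under the same pydub/torchaudio stack (load_as_waveform is a hypothetical helper, not part of the commit):

import io
import os
import tempfile

import torchaudio
from pydub import AudioSegment

def load_as_waveform(audio_bytes):
    # Hypothetical helper: decode arbitrary audio bytes to (waveform, sample_rate).
    wav_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
        AudioSegment.from_file(io.BytesIO(audio_bytes)).export(wav_path, format="wav")
        return torchaudio.load(wav_path)
    finally:
        # Runs whether or not decoding/loading raised, so the temp file never leaks.
        if wav_path is not None and os.path.exists(wav_path):
            os.remove(wav_path)
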
@@ -72,7 +103,7 @@ except Exception as e:
     st.error(f"Error loading models: {str(e)}")
     st.stop()
 
-# Create two columns for layout
+# Create two columns
 col1, col2 = st.columns(2)
 
 with col1:
@@ -84,61 +115,53 @@ with col1:
         st.audio(audio_file)
         st.success("File uploaded successfully!")
 
+        # Reset file pointer
+        audio_file.seek(0)
+
         # Add classify button
         if st.button("Classify Genre"):
             try:
                 with st.spinner("Analyzing audio..."):
-                    # Extract features using Wav2Vec2
+                    # Extract features
                     features = process_audio(audio_file, processor, wav2vec_model)
 
-                    # Get genre predictions
-                    probabilities = classifier.predict(features)
-
-                    # Show results
-                    st.write("### Genre Analysis Results:")
-                    for genre, prob in zip(classifier.genres, probabilities):
-                        # Create a progress bar for each genre
-                        st.write(f"{genre}:")
-                        st.progress(float(prob))
-                        st.write(f"{prob:.2%}")
-
-                    # Show top prediction
-                    top_genre = classifier.genres[np.argmax(probabilities)]
-                    st.write(f"**Predicted Genre:** {top_genre}")
+                    if features is not None:
+                        # Get predictions
+                        probabilities = classifier.predict(features)
+
+                        # Show results
+                        st.write("### Genre Analysis Results:")
+                        for genre, prob in zip(classifier.genres, probabilities):
+                            st.write(f"{genre}:")
+                            st.progress(float(prob))
+                            st.write(f"{prob:.2%}")
+
+                        # Show top prediction
+                        top_genre = classifier.genres[np.argmax(probabilities)]
+                        st.write(f"**Predicted Genre:** {top_genre}")
+
             except Exception as e:
                 st.error(f"Error during analysis: {str(e)}")
 
 with col2:
-    # Display information about the model
     st.write("### About the Model:")
     st.write("""
     This classifier uses:
     - Facebook's Wav2Vec2 for audio feature extraction
     - Custom genre classification layer
-    - Pre-trained on speech recognition
+    - Handles MP3 and WAV formats
     """)
 
     st.write("### Supported Genres:")
     for genre in classifier.genres:
         st.write(f"- {genre}")
 
-    # Add usage tips
     st.write("### Tips for best results:")
     st.write("- Upload clear, high-quality audio")
-    st.write("- Ideal length: 10-30 seconds")
+    st.write("- Best length: 10-30 seconds")
     st.write("- Avoid audio with multiple overlapping genres")
     st.write("- Ensure minimal background noise")
 
-# Update requirements.txt
-if st.sidebar.checkbox("Show requirements.txt contents"):
-    st.sidebar.code("""
-    streamlit==1.31.0
-    torch==2.0.1
-    torchaudio==2.0.1
-    transformers==4.30.2
-    numpy==1.24.3
-    """)
-
 # Footer
 st.markdown("---")
 st.write("Made with ❤️ using Streamlit and Hugging Face Transformers")
 