Update app.py
app.py CHANGED
@@ -10,57 +10,13 @@ import io
import matplotlib.pyplot as plt
import librosa.display
from PIL import Image # For image conversion
-from datetime import datetime

-#
-
-
-
-
-
-
-# Update the path below to your Firebase service account key JSON file
-SERVICE_ACCOUNT_KEY = "serviceAccountKey.json" # <-- Ensure this file exists in your project directory
-if not os.path.exists(SERVICE_ACCOUNT_KEY):
-    raise FileNotFoundError(f"Firebase credentials file {SERVICE_ACCOUNT_KEY} not found!")
-
-# Initialize Firebase Admin for Realtime Database
-cred = credentials.Certificate(SERVICE_ACCOUNT_KEY)
-firebase_admin.initialize_app(cred, {
-    'databaseURL': 'https://sentiment-analysis-7562e-default-rtdb.firebaseio.com/' # <-- Update with your actual DB URL
-})
-
-def upload_file_to_firebase(file_path, destination_blob_name):
-    """
-    Uploads a file to Firebase Storage and returns its public URL.
-    """
-    # Update the bucket name to match your Firebase Storage bucket (usually: your-project-id.appspot.com)
-    bucket_name = "sentiment-analysis-7562e.appspot.com" # <-- Update with your storage bucket name
-    storage_client = storage.Client.from_service_account_json(SERVICE_ACCOUNT_KEY)
-    bucket = storage_client.bucket(bucket_name)
-    blob = bucket.blob(destination_blob_name)
-    blob.upload_from_filename(file_path)
-    blob.make_public()
-    print(f"File uploaded to {blob.public_url}")
-    return blob.public_url
-
-def store_prediction_metadata(file_url, predicted_emotion):
-    """
-    Stores the file URL, predicted emotion, and timestamp in Firebase Realtime Database.
-    """
-    ref = db.reference('predictions')
-    data = {
-        'file_url': file_url,
-        'predicted_emotion': predicted_emotion,
-        'timestamp': datetime.now().isoformat()
-    }
-    new_record_ref = ref.push(data)
-    print(f"Stored metadata with key: {new_record_ref.key}")
-    return new_record_ref.key
-
-# ---------------------------
-# Emotion Recognition Code
-# ---------------------------
+# Try to import noisereduce (if not available, noise reduction will be skipped)
+try:
+    import noisereduce as nr
+    NOISEREDUCE_AVAILABLE = True
+except ImportError:
+    NOISEREDUCE_AVAILABLE = False

# Mapping from emotion labels to emojis
emotion_to_emoji = {
@@ -79,7 +35,7 @@ def add_emoji_to_label(label):
    emoji = emotion_to_emoji.get(label.lower(), "")
    return f"{label.capitalize()} {emoji}"

-# Load the pre-trained SpeechBrain classifier
+# Load the pre-trained SpeechBrain classifier
classifier = foreign_class(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
@@ -96,11 +52,6 @@ def preprocess_audio(audio_file, apply_noise_reduction=False):
    Saves the processed audio to a temporary file and returns its path.
    """
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
-    try:
-        import noisereduce as nr
-        NOISEREDUCE_AVAILABLE = True
-    except ImportError:
-        NOISEREDUCE_AVAILABLE = False
    if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
        y = nr.reduce_noise(y=y, sr=sr)
    if np.max(np.abs(y)) > 0:
@@ -159,7 +110,7 @@ def predict_emotion(audio_file, use_ensemble=False, apply_noise_reduction=False,
    result = classifier.classify_file(temp_file)
    os.remove(temp_file)
    if isinstance(result, tuple) and len(result) > 3:
-        label = result[3][0]
+        label = result[3][0]  # Extract predicted emotion label from the tuple
    else:
        label = str(result)
    return add_emoji_to_label(label.lower())
@@ -183,52 +134,39 @@ def plot_waveform(audio_file):
def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
    """
    Run emotion prediction and generate a waveform plot.
-    Additionally, upload the audio file to Firebase Storage and store the metadata in Firebase Realtime Database.
    Returns a tuple: (emotion label with emoji, waveform image as a PIL Image).
    """
-    # Upload the original audio file to Firebase Storage
-    destination_blob_name = os.path.basename(audio_file)
-    file_url = upload_file_to_firebase(audio_file, destination_blob_name)
-
-    # Predict emotion and generate waveform
    emotion = predict_emotion(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap)
    waveform = plot_waveform(audio_file)
-
-    # Store metadata (file URL and predicted emotion) in Firebase Realtime Database
-    record_key = store_prediction_metadata(file_url, emotion)
-    print(f"Record stored with key: {record_key}")
-
    return emotion, waveform

-# ---------------------------
-# Gradio App UI
-# ---------------------------
with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: Arial;}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Enhanced Emotion Recognition</h1>")
    gr.Markdown(
        "Upload an audio file, and the model will predict the emotion using a wav2vec2 model fine-tuned on IEMOCAP data. "
-        "The prediction is accompanied by an emoji, and you can view the audio's waveform. "
-        "
+        "The prediction is accompanied by an emoji in the output, and you can also view the audio's waveform. "
+        "Use the options below to adjust ensemble prediction and noise reduction settings."
    )

    with gr.Tabs():
        with gr.TabItem("Emotion Recognition"):
            with gr.Row():
                audio_input = gr.Audio(type="filepath", label="Upload Audio")
-
-
+                use_ensemble = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
+                apply_noise_reduction = gr.Checkbox(label="Apply Noise Reduction", value=False)
            with gr.Row():
-
-
+                segment_duration = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=3.0, label="Segment Duration (s)")
+                overlap = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=1.0, label="Segment Overlap (s)")
            predict_button = gr.Button("Predict Emotion")
            result_text = gr.Textbox(label="Predicted Emotion")
            waveform_image = gr.Image(label="Audio Waveform", type="pil")

            predict_button.click(
                predict_and_plot,
-                inputs=[audio_input,
+                inputs=[audio_input, use_ensemble, apply_noise_reduction, segment_duration, overlap],
                outputs=[result_text, waveform_image]
            )
+
        with gr.TabItem("About"):
            gr.Markdown("""
            **Enhanced Emotion Recognition App**
@@ -238,8 +176,7 @@ with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: A
            - Ensemble Prediction for long audio files.
            - Optional Noise Reduction.
            - Visualization of the audio waveform.
-            - Emoji representation of the predicted emotion.
-            - Audio file and prediction metadata stored in Firebase Realtime Database.
+            - Emoji representation of the predicted emotion in the output.

            **Credits:**
            - [SpeechBrain](https://speechbrain.github.io)
@@ -247,4 +184,4 @@ with gr.Blocks(css=".gradio-container {background-color: #f7f7f7; font-family: A
            """)

if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
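This change hoists the `noisereduce` import out of `preprocess_audio` into a module-level try/except, so `NOISEREDUCE_AVAILABLE` is evaluated once at import time rather than on every call. A minimal sketch of how that flag gates the optional denoising step; the `load_audio` helper and its peak normalization are illustrative assumptions, and only the flag pattern and the `nr.reduce_noise(y=y, sr=sr)` call mirror the diff:

```python
import numpy as np
import librosa

# Module-level optional dependency: if noisereduce is missing,
# noise reduction is silently skipped instead of crashing at import.
try:
    import noisereduce as nr
    NOISEREDUCE_AVAILABLE = True
except ImportError:
    NOISEREDUCE_AVAILABLE = False

def load_audio(path, apply_noise_reduction=False):
    """Load audio at 16 kHz mono and optionally denoise it (illustrative sketch)."""
    y, sr = librosa.load(path, sr=16000, mono=True)
    if apply_noise_reduction and NOISEREDUCE_AVAILABLE:
        # Spectral-gating denoiser; only runs when the package is installed
        y = nr.reduce_noise(y=y, sr=sr)
    # Peak-normalize, guarding against an all-silence clip
    if np.max(np.abs(y)) > 0:
        y = y / np.max(np.abs(y))
    return y, sr
```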
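The updated `inputs=` list hands the new checkboxes and sliders to `predict_and_plot` positionally, so the component order must match the function's parameter order. A stripped-down, self-contained sketch of that wiring; the stub `predict_and_plot` body is a placeholder, not the app's real prediction logic:

```python
import gradio as gr

def predict_and_plot(audio_file, use_ensemble, apply_noise_reduction, segment_duration, overlap):
    # Stub: the real app returns (emotion label with emoji, waveform PIL image)
    return f"neutral (ensemble={use_ensemble}, nr={apply_noise_reduction})", None

with gr.Blocks() as demo:
    audio_input = gr.Audio(type="filepath", label="Upload Audio")
    use_ensemble = gr.Checkbox(label="Use Ensemble Prediction (for long audio)", value=False)
    apply_noise_reduction = gr.Checkbox(label="Apply Noise Reduction", value=False)
    segment_duration = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=3.0, label="Segment Duration (s)")
    overlap = gr.Slider(minimum=0.0, maximum=5.0, step=0.5, value=1.0, label="Segment Overlap (s)")
    result_text = gr.Textbox(label="Predicted Emotion")
    waveform_image = gr.Image(label="Audio Waveform", type="pil")

    # Component order here must line up with predict_and_plot's parameters
    gr.Button("Predict Emotion").click(
        predict_and_plot,
        inputs=[audio_input, use_ensemble, apply_noise_reduction, segment_duration, overlap],
        outputs=[result_text, waveform_image],
    )

if __name__ == "__main__":
    demo.launch()
```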
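`plot_waveform` itself is untouched by this diff, but the imports kept at the top (`io`, `matplotlib.pyplot`, `librosa.display`, `PIL.Image` "for image conversion") point at the usual render-to-buffer pattern for returning a PIL image to `gr.Image(type="pil")`. A hedged sketch of such a helper; the figure size, title, and `waveshow` call are assumptions, not the app's actual implementation:

```python
import io
import librosa
import librosa.display
import matplotlib
matplotlib.use("Agg")  # headless rendering for a server app
import matplotlib.pyplot as plt
from PIL import Image

def plot_waveform(audio_file):
    """Render the waveform to an in-memory PNG and return it as a PIL Image (sketch)."""
    y, sr = librosa.load(audio_file, sr=16000, mono=True)
    fig, ax = plt.subplots(figsize=(8, 3))
    librosa.display.waveshow(y, sr=sr, ax=ax)  # waveplot() in older librosa versions
    ax.set_title("Audio Waveform")
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)  # free the figure so repeated predictions don't leak memory
    buf.seek(0)
    return Image.open(buf)
```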