omsandeeppatil committed
Commit d5038df · verified · 1 Parent(s): 1de8eea

Update app.py

Files changed (1)
  1. app.py +96 -31
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import torch
 import numpy as np
 from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
+import librosa
 
 # Initialize model and processor
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -21,58 +22,119 @@ EMOTION_LABELS = {
     6: "surprise"
 }
 
+def preprocess_audio(audio, target_sr=16000):
+    """Enhanced audio preprocessing"""
+    try:
+        # Convert to numpy array and ensure float32
+        audio = np.array(audio, dtype=np.float32)
+
+        # Convert to mono if stereo
+        if len(audio.shape) > 1:
+            audio = librosa.to_mono(audio.T)
+
+        # Resample if needed
+        if target_sr != 16000:
+            audio = librosa.resample(audio, orig_sr=target_sr, target_sr=16000)
+
+        # Apply preprocessing steps
+        # 1. Noise reduction
+        audio = librosa.effects.preemphasis(audio)
+
+        # 2. Normalize
+        audio = librosa.util.normalize(audio)
+
+        # 3. Voice activity detection
+        intervals = librosa.effects.split(audio, top_db=20)
+        if len(intervals) > 0:
+            audio = np.concatenate([audio[start:end] for start, end in intervals])
+
+        # 4. Ensure minimum length (1 second)
+        if len(audio) < 16000:
+            audio = np.pad(audio, (0, 16000 - len(audio)))
+
+        # 5. Take center 3 seconds if too long
+        if len(audio) > 48000:  # 3 seconds at 16kHz
+            center = len(audio) // 2
+            start = center - 24000
+            end = center + 24000
+            audio = audio[start:end]
+
+        return audio
+
+    except Exception as e:
+        print(f"Preprocessing error: {str(e)}")
+        return None
+
+def get_emotion_history():
+    """Get emotion detection history"""
+    if not hasattr(get_emotion_history, "history"):
+        get_emotion_history.history = []
+    return get_emotion_history.history
+
 def process_audio(audio):
     """Process audio chunk and return emotion"""
     if audio is None:
-        return ""
-
+        return "No audio input detected"
+
     try:
         # Get the audio data
         if isinstance(audio, tuple):
             audio = audio[1]
-
-        # Convert to numpy array and ensure float32 type
-        audio = np.array(audio, dtype=np.float32)
-
-        # Ensure we have mono audio
-        if len(audio.shape) > 1:
-            audio = audio.mean(axis=1)
-
-        # Normalize audio if needed
-        if audio.max() > 1.0 or audio.min() < -1.0:
-            audio = audio / max(abs(audio.max()), abs(audio.min()))
 
-        # Ensure we have non-zero audio
-        if len(audio) == 0 or np.all(audio == 0):
-            return "No audio detected"
-
+        # Preprocess audio
+        processed_audio = preprocess_audio(audio)
+        if processed_audio is None:
+            return "Audio preprocessing failed"
+
+        if np.max(np.abs(processed_audio)) < 0.01:
+            return "Audio too quiet"
+
         # Prepare input for the model
         inputs = feature_extractor(
-            audio,
+            processed_audio,
            sampling_rate=16000,
             return_tensors="pt",
             padding=True
         )
 
-        # Ensure all tensors are float32
+        # Move to device and ensure float32
        inputs = {k: v.to(device, dtype=torch.float32) for k, v in inputs.items()}
 
         # Get prediction
         with torch.no_grad():
             outputs = model(**inputs)
             logits = outputs.logits
-            predicted_id = torch.argmax(logits, dim=-1).item()
+            probs = torch.nn.functional.softmax(logits, dim=-1)[0]
+
+            # Get top 2 predictions
+            top2_probs, top2_ids = torch.topk(probs, 2)
+
+            # Convert to percentages
+            top2_probs = [p * 100 for p in top2_probs.cpu().numpy()]
+            top2_emotions = [EMOTION_LABELS[idx.item()] for idx in top2_ids]
+
+        # Update history
+        history = get_emotion_history()
+        history.append(top2_emotions[0])
+        if len(history) > 5:
+            history.pop(0)
+
+        # Get most common emotion in history
+        if len(history) >= 3:
+            from collections import Counter
+            most_common = Counter(history).most_common(1)[0][0]
+        else:
+            most_common = top2_emotions[0]
+
+        result = f"Primary: {top2_emotions[0]} ({top2_probs[0]:.1f}%)\n"
+        result += f"Secondary: {top2_emotions[1]} ({top2_probs[1]:.1f}%)\n"
+        result += f"Trending: {most_common}"
 
-        # Get probabilities
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        confidence = probs[0][predicted_id].item() * 100
-
-        emotion = EMOTION_LABELS[predicted_id]
-        return f"{emotion} (confidence: {confidence:.1f}%)"
-
+        return result
+
     except Exception as e:
-        print(f"Error in audio processing: {str(e)}")
-        return "Error processing audio. Please try again."
+        print(f"Error in processing: {str(e)}")
+        return "Processing error. Please try again."
 
 # Create Gradio interface
 demo = gr.Interface(
@@ -86,9 +148,12 @@ demo = gr.Interface(
             show_label=True
         )
     ],
-    outputs=gr.Textbox(label="Detected Emotion"),
-    title="Live Emotion Detection",
-    description="Speak into your microphone to detect emotions in real-time.",
+    outputs=gr.Textbox(
+        label="Detected Emotions",
+        lines=3
+    ),
+    title="Enhanced Live Emotion Detection",
+    description="Speak naturally into your microphone. Shows primary and secondary emotions with confidence levels.",
     live=True,
     allow_flagging=False
 )
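
A quick way to sanity-check the new preprocess_audio() outside the Space is to feed it a short synthetic tone and confirm the pad/crop behaviour. This is only an illustrative sketch, not part of the commit: the half-second 440 Hz tone and the assertions are made up here, and it assumes preprocess_audio from app.py (plus numpy and librosa) is available in the same script.

import numpy as np

# Illustrative smoke test, assuming preprocess_audio from app.py is pasted or imported above.
sr = 16000
t = np.linspace(0, 0.5, int(0.5 * sr), endpoint=False)
half_second_tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

out = preprocess_audio(half_second_tone)

# Inputs shorter than 1 s are padded to 16000 samples; inputs longer than 3 s are cropped to 48000.
assert out is not None
assert 16000 <= len(out) <= 48000
print(f"output length: {len(out)} samples")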
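
One detail worth noting in the hunk above: process_audio keeps only audio[1] from the (sample_rate, data) tuple that Gradio's microphone input returns, and it calls preprocess_audio(audio) without passing a rate, so the librosa.resample branch (guarded by target_sr != 16000, with target_sr defaulting to 16000) can never run even when the browser records at 44.1 or 48 kHz. Below is a hedged sketch of how the real rate could be threaded through; the helper name is hypothetical and this change is not part of the commit.

def extract_samples_and_rate(audio, default_sr=16000):
    """Split Gradio's (sample_rate, data) tuple; fall back to default_sr for bare arrays."""
    if isinstance(audio, tuple):
        sample_rate, samples = audio
        return samples, sample_rate
    return audio, default_sr

# Inside process_audio one could then call:
#     samples, sample_rate = extract_samples_and_rate(audio)
#     processed_audio = preprocess_audio(samples, target_sr=sample_rate)
# (preprocess_audio uses its target_sr argument as the incoming rate when it calls
#  librosa.resample(audio, orig_sr=target_sr, target_sr=16000), so this enables resampling).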