invincible-jha committed · verified
Commit 784383b · 1 Parent(s): 9ec2a83

Upload app.py

Files changed (1)
  1. app.py +110 -48
app.py CHANGED
@@ -7,7 +7,7 @@ import plotly.graph_objects as go
 
 class ModelManager:
     def __init__(self):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device("cpu")
         self.models = {}
         self.tokenizers = {}
         self.processors = {}
@@ -15,12 +15,23 @@ class ModelManager:
 
     def load_models(self):
         print("Loading Whisper model...")
-        self.processors['whisper'] = WhisperProcessor.from_pretrained("openai/whisper-base")
-        self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(self.device)
+        self.processors['whisper'] = WhisperProcessor.from_pretrained(
+            "openai/whisper-base",
+            device_map="cpu"
+        )
+        self.models['whisper'] = WhisperForConditionalGeneration.from_pretrained(
+            "openai/whisper-base",
+            device_map="cpu"
+        )
 
         print("Loading emotion model...")
-        self.tokenizers['emotion'] = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
-        self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained("j-hartmann/emotion-english-distilroberta-base").to(self.device)
+        self.tokenizers['emotion'] = AutoTokenizer.from_pretrained(
+            "j-hartmann/emotion-english-distilroberta-base"
+        )
+        self.models['emotion'] = AutoModelForSequenceClassification.from_pretrained(
+            "j-hartmann/emotion-english-distilroberta-base",
+            device_map="cpu"
+        )
 
 class AudioProcessor:
     def __init__(self):
@@ -28,14 +39,22 @@ class AudioProcessor:
         self.n_mfcc = 13
 
     def process_audio(self, audio_path):
-        waveform, sr = librosa.load(audio_path, sr=self.sample_rate)
-        return waveform, self._extract_features(waveform)
+        try:
+            waveform, sr = librosa.load(audio_path, sr=self.sample_rate)
+            return waveform, self._extract_features(waveform)
+        except Exception as e:
+            print(f"Error processing audio: {str(e)}")
+            raise
 
     def _extract_features(self, waveform):
-        return {
-            'mfcc': librosa.feature.mfcc(y=waveform, sr=self.sample_rate, n_mfcc=self.n_mfcc),
-            'energy': librosa.feature.rms(y=waveform)[0]
-        }
+        try:
+            return {
+                'mfcc': librosa.feature.mfcc(y=waveform, sr=self.sample_rate, n_mfcc=self.n_mfcc),
+                'energy': librosa.feature.rms(y=waveform)[0]
+            }
+        except Exception as e:
+            print(f"Error extracting features: {str(e)}")
+            raise
 
 class Analyzer:
     def __init__(self):
@@ -45,45 +64,80 @@ class Analyzer:
         print("Analyzer initialization complete")
 
     def analyze(self, audio_path):
-        print(f"Processing audio file: {audio_path}")
-        waveform, features = self.audio_processor.process_audio(audio_path)
-
-        print("Transcribing audio...")
-        inputs = self.model_manager.processors['whisper'](waveform, return_tensors="pt").input_features.to(self.model_manager.device)
-        predicted_ids = self.model_manager.models['whisper'].generate(inputs)
-        transcription = self.model_manager.processors['whisper'].batch_decode(predicted_ids, skip_special_tokens=True)[0]
-
-        print("Analyzing emotions...")
-        inputs = self.model_manager.tokenizers['emotion'](transcription, return_tensors="pt", padding=True, truncation=True)
-        outputs = self.model_manager.models['emotion'](**inputs)
-        emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)
-
-        emotion_labels = ['anger', 'fear', 'joy', 'neutral', 'sadness', 'surprise']
-        emotion_scores = {
-            label: float(score)
-            for label, score in zip(emotion_labels, emotions[0])
-        }
-
-        return {
-            'transcription': transcription,
-            'emotions': emotion_scores
-        }
+        try:
+            print(f"Processing audio file: {audio_path}")
+            waveform, features = self.audio_processor.process_audio(audio_path)
+
+            print("Transcribing audio...")
+            inputs = self.model_manager.processors['whisper'](
+                waveform,
+                return_tensors="pt"
+            ).input_features
+
+            predicted_ids = self.model_manager.models['whisper'].generate(inputs)
+            transcription = self.model_manager.processors['whisper'].batch_decode(
+                predicted_ids,
+                skip_special_tokens=True
+            )[0]
+
+            print("Analyzing emotions...")
+            inputs = self.model_manager.tokenizers['emotion'](
+                transcription,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            )
+
+            outputs = self.model_manager.models['emotion'](**inputs)
+            emotions = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+            emotion_labels = ['anger', 'fear', 'joy', 'neutral', 'sadness', 'surprise']
+            emotion_scores = {
+                label: float(score)
+                for label, score in zip(emotion_labels, emotions[0])
+            }
+
+            return {
+                'transcription': transcription,
+                'emotions': emotion_scores
+            }
+        except Exception as e:
+            print(f"Error in analysis: {str(e)}")
+            raise
 
 def create_emotion_plot(emotions):
-    fig = go.Figure(data=[
-        go.Bar(x=list(emotions.keys()), y=list(emotions.values()))
-    ])
-    fig.update_layout(
-        title='Emotion Analysis',
-        yaxis_range=[0, 1]
-    )
-    return fig.to_html()
+    try:
+        fig = go.Figure(data=[
+            go.Bar(
+                x=list(emotions.keys()),
+                y=list(emotions.values()),
+                marker_color='rgb(55, 83, 109)'
+            )
+        ])
+
+        fig.update_layout(
+            title='Emotion Analysis',
+            xaxis_title='Emotion',
+            yaxis_title='Score',
+            yaxis_range=[0, 1],
+            template='plotly_white',
+            height=400
+        )
+
+        return fig.to_html(include_plotlyjs=True)
+    except Exception as e:
+        print(f"Error creating plot: {str(e)}")
+        return "Error creating visualization"
 
 print("Initializing application...")
 analyzer = Analyzer()
 
 def process_audio(audio_file):
     try:
+        if audio_file is None:
+            return "No audio file provided", "Please provide an audio file"
+
         print(f"Processing audio file: {audio_file}")
         results = analyzer.analyze(audio_file)
 
@@ -92,20 +146,28 @@ def process_audio(audio_file):
             create_emotion_plot(results['emotions'])
         )
     except Exception as e:
-        print(f"Error processing audio: {str(e)}")
-        return str(e), "Error in analysis"
+        error_msg = f"Error processing audio: {str(e)}"
+        print(error_msg)
+        return error_msg, "Error in analysis"
 
 print("Creating Gradio interface...")
 interface = gr.Interface(
     fn=process_audio,
-    inputs=gr.Audio(sources=["microphone", "upload"]),  # Fixed parameter
+    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
     outputs=[
         gr.Textbox(label="Transcription"),
         gr.HTML(label="Emotion Analysis")
     ],
     title="Vocal Biomarker Analysis",
-    description="Analyze voice for emotional indicators"
+    description="Analyze voice for emotional indicators",
+    examples=[],
+    cache_examples=False
 )
 
-print("Launching application...")
-interface.launch(share=False)
+if __name__ == "__main__":
+    print("Launching application...")
+    interface.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False
+    )
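
With interface.launch() now behind the __main__ guard, importing app.py no longer starts the Gradio server, which makes a quick CPU-only smoke test of the pipeline possible. Note that the device_map="cpu" arguments in the from_pretrained calls generally require the accelerate package to be installed alongside transformers. Below is a minimal sketch, assuming the file above is saved as app.py in the working directory and a short audio clip is available locally; smoke_test.py and sample.wav are placeholder names, not part of this commit.

# smoke_test.py -- hypothetical helper script, not part of this commit
# Runs the analysis pipeline once on CPU without launching the Gradio UI.
from app import analyzer, create_emotion_plot

if __name__ == "__main__":
    # analyze() returns a dict with 'transcription' and 'emotions' keys.
    results = analyzer.analyze("sample.wav")  # placeholder path to any short audio clip
    print("Transcription:", results["transcription"])
    print("Emotion scores:", results["emotions"])

    # Write the Plotly bar chart to a standalone HTML file for inspection.
    with open("emotions.html", "w") as f:
        f.write(create_emotion_plot(results["emotions"]))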