andromeda01111 committed
Commit 7ded071 · verified · 1 Parent(s): 7c32789

Update app.py

Files changed (1):
  1. app.py +62 -1
app.py CHANGED
@@ -1,3 +1,64 @@
  import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ from transformers import AutoConfig, Wav2Vec2FeatureExtractor
+ from src.models import Wav2Vec2ForSpeechClassification
 
- gr.Interface.load("models/andromeda01111/Malayalam_SA").launch()
+
+ # Load the fine-tuned model, its label mapping, and the matching feature extractor.
+ model_name_or_path = "andromeda01111/Malayalam_SA"
+ config = AutoConfig.from_pretrained(model_name_or_path)
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
+ sampling_rate = feature_extractor.sampling_rate
+ model = Wav2Vec2ForSpeechClassification.from_pretrained(model_name_or_path)
+
+
+ def speech_file_to_array_fn(path, sampling_rate):
+     # Read the audio file and resample it from its native rate to the model's rate.
+     speech_array, _sampling_rate = torchaudio.load(path)
+     resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
+     speech = resampler(speech_array).squeeze().numpy()
+     return speech
+
+
+ def predict(path, sampling_rate):
+     # Run the classifier on one audio file and return per-emotion scores.
+     speech = speech_file_to_array_fn(path, sampling_rate)
+     features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
+
+     input_values = features.input_values
+     attention_mask = features.attention_mask
+
+     with torch.no_grad():
+         logits = model(input_values, attention_mask=attention_mask).logits
+
+     # Softmax over the logits gives one probability per emotion label.
+     scores = F.softmax(logits, dim=1).cpu().numpy()[0]
+     output_emotion = [{"Emotion": config.id2label[i], "Score": f"{score * 100:.1f}%"} for i, score in enumerate(scores)]
+
+     return output_emotion
+
+
+ # Wrapper for Gradio: format each prediction as "Emotion: Score".
+ def gradio_predict(audio):
+     predictions = predict(audio, sampling_rate)
+     return [f"{pred['Emotion']}: {pred['Score']}" for pred in predictions]
+
+
+ # Gradio interface: one read-only textbox per emotion label.
+ emotions = [config.id2label[i] for i in range(len(config.id2label))]
+ outputs = [gr.Textbox(label=emotion, interactive=False) for emotion in emotions]
+
+ interface = gr.Interface(
+     fn=gradio_predict,
+     inputs=gr.Audio(label="Upload Audio", type="filepath"),
+     outputs=outputs,
+     title="Emotion Recognition",
+     description="Upload an audio file to predict emotions and their corresponding percentages.",
+ )
+
+ # Launch the app
+ interface.launch()