import gradio as gr
import os
import numpy as np
import librosa
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from transformers import AutoFeatureExtractor, TFWav2Vec2Model
from sklearnex import patch_sklearn, unpatch_sklearn
patch_sklearn()  # Patch scikit-learn with Intel Extension for Scikit-learn optimizations.
import xgboost as xgb


MAX_DURATION = 2  # Maximum duration (in seconds) of the input audio clips.
# Sampling rate is the number of samples of audio recorded every second.
SAMPLING_RATE = 16000
BATCH_SIZE = 2  # Batch size used when training and evaluating the model.
NUM_CLASSES = 8  # Number of classes in our dataset (the 8 RAVDESS emotions).
HIDDEN_DIM = 768  # Dimension of the model output (768 for Wav2Vec 2.0 Base).
MAX_SEQ_LENGTH = MAX_DURATION * SAMPLING_RATE  # Maximum length of the input audio, in samples.
# Wav2Vec 2.0 produces one output frame roughly every 20 ms, so 2 s of audio yields 99 frames.
MAX_FRAMES = 99
MAX_EPOCHS = 5  # Maximum number of training epochs.
RAVDESS_CLASS_LABELS = ("angry", "calm", "disgust", "fear", "happy", "neutral", "sad", "surprise")
MODEL_CHECKPOINT = "facebook/wav2vec2-base"

labels = RAVDESS_CLASS_LABELS
label2id, id2label = dict(), dict()


def mean_pool(hidden_states, feature_lengths):
    """Average the Wav2Vec 2.0 hidden states over the non-padded frames of each clip."""
    attention_mask = tf.sequence_mask(
        feature_lengths, maxlen=MAX_FRAMES, dtype=tf.dtypes.int64
    )
    padding_mask = tf.cast(
        tf.reverse(tf.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]),
        dtype=tf.dtypes.bool,
    )
    # Zero out the hidden states that correspond to padded frames.
    hidden_states = tf.where(
        tf.broadcast_to(
            tf.expand_dims(~padding_mask, -1), (BATCH_SIZE, MAX_FRAMES, HIDDEN_DIM)
        ),
        0.0,
        hidden_states,
    )
    # Sum over the time axis and divide by the number of valid frames.
    pooled_state = tf.math.reduce_sum(hidden_states, axis=1) / tf.reshape(
        tf.math.reduce_sum(tf.cast(padding_mask, dtype=tf.dtypes.float32), axis=1),
        [-1, 1],
    )
    return pooled_state


class TFWav2Vec2ForAudioClassification(keras.Model):
    """Wraps Wav2Vec 2.0 and returns a pooled embedding for a downstream classifier."""

    def __init__(self, model_checkpoint):
        super().__init__()
        # Instantiate the Wav2Vec 2.0 model without the classification head.
        self.wav2vec2 = TFWav2Vec2Model.from_pretrained(
            model_checkpoint, apply_spec_augment=False, from_pt=True
        )
        self.pooling = layers.GlobalAveragePooling1D()
        self.flat = layers.Flatten()
        self.intermediate_layer_dropout = layers.Dropout(0.5)

    def call(self, inputs):
        # inputs[0]: padded waveforms, inputs[1]: attention mask from the feature extractor.
        hidden_states = self.wav2vec2(inputs[0])[0]
        if tf.is_tensor(inputs[1]):
            # Derive the number of valid feature frames from the attention mask
            # and mean-pool only over those frames.
            audio_lengths = tf.cumsum(inputs[1], -1)[:, -1]
            feature_lengths = self.wav2vec2.wav2vec2._get_feat_extract_output_lengths(
                audio_lengths
            )
            pooled_state = mean_pool(hidden_states, feature_lengths)
        else:
            # Without an attention mask, fall back to plain average pooling over all frames.
            pooled_state = self.pooling(hidden_states)

        intermediate_state = self.flat(self.intermediate_layer_dropout(pooled_state))

        return intermediate_state

# Build the feature-extraction model and restore its fine-tuned weights.
wav2vec2_model = TFWav2Vec2ForAudioClassification(MODEL_CHECKPOINT)
wav2vec2_model.load_weights('model.keras')

# Map class labels to ids and back so predictions can be reported by name.
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

feature_extractor = AutoFeatureExtractor.from_pretrained(
    MODEL_CHECKPOINT, return_attention_mask=True
)
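# The XGBoost classifier below consumes the 768-dimensional embedding produced by
# the Wav2Vec 2.0 model above and predicts one of the NUM_CLASSES emotion labels.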
xgb_params = {
    'objective':                    'binary:logistic',
    'predictor':                    'cpu_predictor',
    'disable_default_eval_metric':  'true',
}

model_xgb = xgb.XGBClassifier(**xgb_params)
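# NOTE (assumption): the classifier needs trained parameters before it can predict.
# The file name below is a placeholder for the booster saved during training
# (e.g. with model_xgb.save_model); adjust it to match the actual artifact.
model_xgb.load_model("xgb_model.json")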

def greet(audio):
    # Gradio delivers the recording as a (sample_rate, waveform) tuple.
    sr, data = audio
    data = data.astype(np.float32)
    if data.ndim > 1:  # Down-mix stereo recordings to mono.
        data = data.mean(axis=1)
    if sr != SAMPLING_RATE:  # Wav2Vec 2.0 expects 16 kHz audio.
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLING_RATE)

    inp = feature_extractor(
        data,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=MAX_SEQ_LENGTH,
        truncation=True,
        padding="max_length",  # Pad/truncate to MAX_SEQ_LENGTH samples so the model sees exactly MAX_FRAMES frames.
    )
    # Pooled Wav2Vec 2.0 embedding first, then the XGBoost classifier on top of it.
    pred = wav2vec2_model.predict(
        [np.array(inp["input_values"]), np.array(inp["attention_mask"])]
    )
    pred = model_xgb.predict(pred)
    return id2label[str(pred[0])]

# Minimal Gradio UI: record or upload an audio clip and get the predicted emotion back.
iface = gr.Interface(fn=greet, inputs="audio", outputs="text")
iface.launch()