p-alonso committed on
Commit
c06a2fa
·
1 Parent(s): 2b74698

Upload feature extractor

Browse files
feature_extraction_maest.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Feature extractor class for Music Audio Efficient Spectrogram Transformer.
17
+ """
18
+
19
+
20
+ from typing import List, Optional, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
26
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
27
+ from transformers.feature_extraction_utils import BatchFeature
28
+ from transformers.utils import TensorType, logging
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
class MAESTFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Music Audio Efficient Spectrogram Transformer (MAEST) feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    This class extracts mel-filter bank features from raw audio, pads/truncates them to a fixed length and normalizes
    them using a mean and standard deviation.

    Args:
        feature_size (`int`, *optional*, defaults to 1):
            The feature dimension of the extracted features.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
        num_mel_bins (`int`, *optional*, defaults to 96):
            Number of Mel-frequency bins.
        max_length (`int`, *optional*, defaults to 1876):
            Maximum length (in frames) to which to pad/truncate the extracted features. Set to -1 (or any
            non-positive value, or `None`) to deactivate the functionality.
        padding_value (`float`, *optional*, defaults to 0.0):
            Padding value forwarded to the superclass. The mel-spectrogram itself is always padded with zeros.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the log-Mel features using `mean` and `std`.
        mean (`float`, *optional*, defaults to 2.06755686098554):
            The mean value used to normalize the log-Mel features. Uses the Discogs20 mean by default.
        std (`float`, *optional*, defaults to 1.268292820667291):
            The standard deviation value used to normalize the log-Mel features. Uses the Discogs20 standard deviation
            by default.
        return_attention_mask (`bool`, *optional*, defaults to `False`):
            Whether or not [`~MAESTFeatureExtractor.__call__`] should return `attention_mask` (the current `__call__`
            implementation only returns `input_values`).
        n_fft (`int`, *optional*, defaults to 512):
            Length of the FFT window.
        hop_length (`int`, *optional*, defaults to 256):
            Number of samples between successive frames.
        log_compression (`str`, *optional*, defaults to `"logC"`):
            Type of log compression to apply to the mel-spectrogram. Can be one of [`None`, `log`, `logC`].
    """

    model_input_names = ["input_values", "attention_mask"]

    def __init__(
        self,
        feature_size=1,
        sampling_rate=16000,
        num_mel_bins=96,
        max_length=1876,
        padding_value=0.0,
        do_normalize=True,
        mean=2.06755686098554,
        std=1.268292820667291,
        return_attention_mask=False,
        n_fft=512,
        hop_length=256,
        log_compression="logC",
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.log_compression = log_compression
        self.num_mel_bins = num_mel_bins
        self.max_length = max_length
        self.do_normalize = do_normalize
        self.mean = mean
        self.std = std
        self.return_attention_mask = return_attention_mask

        # The window and the mel filter bank are kept as plain Python lists and turned back into arrays on every
        # extraction call.
        self.window = window_function(
            window_length=self.n_fft,
            name="hann",
        ).tolist()

        self.mel_fb = mel_filter_bank(
            num_frequency_bins=self.n_fft // 2 + 1,
            num_mel_filters=self.num_mel_bins,
            min_frequency=0,
            max_frequency=self.sampling_rate / 2,
            sampling_rate=self.sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        ).tolist()

    def _extract_fbank_features(
        self,
        waveform: np.ndarray,
        max_length: Optional[int],
    ) -> np.ndarray:
        """
        Get mel-spectrogram features using audio_utils.

        Args:
            waveform (`np.ndarray`):
                Mono audio samples to analyze.
            max_length (`int`, *optional*):
                Number of frames to zero-pad or truncate the result to. `None` or a non-positive value leaves the
                number of frames untouched.

        Returns:
            `np.ndarray`: A `float32` array whose first axis indexes frames.

        Raises:
            ValueError: If `self.log_compression` is not one of `None`, `"log"` or `"logC"`.
        """

        melspec = spectrogram(
            waveform,
            window=np.array(self.window),
            frame_length=self.n_fft,
            hop_length=self.hop_length,
            power=2,
            mel_filters=np.array(self.mel_fb),
            min_value=1e-30,
            mel_floor=1e-30,
            pad_mode="constant",
        ).T

        if not self.log_compression:
            pass
        elif self.log_compression == "log":
            melspec = np.log(melspec + np.finfo(float).eps)
        elif self.log_compression == "logC":
            melspec = np.log10(1 + melspec * 10000)
        else:
            raise ValueError(
                f"`log_compression` can only be one of [None, 'log', 'logC'], but got: {self.log_compression}"
            )

        # Cast to a C-contiguous float32 array; this reproduces what the former `torch.Tensor(...)` round-trip
        # yielded without needing torch for a simple cast and pad.
        melspec = np.ascontiguousarray(melspec, dtype=np.float32)

        if max_length is not None and max_length > 0:
            difference = max_length - melspec.shape[0]

            # pad (with zeros, at the end of the frame axis) or truncate, depending on difference
            if difference > 0:
                melspec = np.pad(melspec, ((0, difference), (0, 0)), mode="constant", constant_values=0.0)
            elif difference < 0:
                melspec = melspec[0:max_length, :]

        return melspec

    def normalize(self, input_values: np.ndarray) -> np.ndarray:
        """
        Standardize features as `(input_values - mean) / (2 * std)`.

        Note that the divisor is twice the configured standard deviation.
        """
        return (input_values - (self.mean)) / (self.std * 2)

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
                stereo, i.e. single float per timestep.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of a list of numpy arrays. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A batch with a single `input_values` entry holding one feature array per input
            sequence.

        Raises:
            ValueError: If `sampling_rate` differs from the configured one, or if a numpy input has more than two
                dimensions.
        """

        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
        elif not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif raw_speech.dtype == np.dtype(np.float64):
            # dtypes are compared by equality rather than identity
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_speech = [raw_speech]

        # extract fbank features and pad/truncate to max_length
        features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech]

        # normalization is applied to the very arrays that end up in the batch, so no stale intermediate list
        # can be normalized by mistake
        if self.do_normalize:
            features = [self.normalize(feature) for feature in features]

        # convert into BatchFeature
        padded_inputs = BatchFeature({"input_values": features})

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs
preprocessor_config.json CHANGED
The diff for this file is too large to render. See raw diff