birgermoell committed
Commit • f9c55bd
1 Parent(s): 8093841
Updated to input wav file directly
feature_extractor.py (+10 -3)
CHANGED
```diff
@@ -40,24 +40,31 @@ def change_sample_rate(y, sample_rate, new_sample_rate):
     value = librosa.resample(y, sample_rate, new_sample_rate)
     return value
 
+def stereo_to_mono(audio_input):
+    X = audio_input.mean(axis=1, keepdims=True)
+    X = np.squeeze(X)
+    return X
+
 def get_wav2vecembeddings_from_audiofile(wav_file):
     print("the file is", wav_file)
     speech, sample_rate = sf.read(wav_file)
+
+    if len(speech.shape) > 1:
+        speech = stereo_to_mono(speech)
     # change sample rate to 16 000 hertz
     resampled = change_sample_rate(speech, sample_rate, new_sample_rate)
     print("the speech is", speech)
-    input_values = processor(
+    input_values = processor(wav_file, return_tensors="pt", padding=True, sampling_rate=new_sample_rate) # there is no truncation param anymore
     print("input values", input_values)
     # import pdb
     # pdb.set_trace()
 
     with torch.no_grad():
         encoded_states = model(
-
+            input_values=input_values["input_ids"],
             # attention_mask=input_values["attention_mask"],
             output_hidden_states=True
         )
-
     last_hidden_state = encoded_states.hidden_states[-1] # The last hidden-state is the first element of the output tuple
     print("getting wav2vec2 embeddings")
     print(last_hidden_state)
```