birgermoell
commited on
Commit
•
8093841
1
Parent(s):
a875c0d
Update resampling
Browse files- feature_extractor.py +25 -20
- readme.MD +1 -1
feature_extractor.py
CHANGED
@@ -6,13 +6,13 @@ from transformers import AutoTokenizer, Wav2Vec2ForCTC
|
|
6 |
import torch
|
7 |
import numpy as np
|
8 |
import glob
|
|
|
9 |
import numpy
|
10 |
import os.path
|
11 |
|
12 |
processor = AutoTokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
13 |
-
|
14 |
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
15 |
-
|
16 |
# Dementia path
|
17 |
# /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
|
18 |
# cookie dementia /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
|
@@ -35,28 +35,33 @@ def feature_extractor(path):
|
|
35 |
if not os.path.isfile(wav_file + ".wav2vec2.pt"):
|
36 |
get_wav2vecembeddings_from_audiofile(wav_file)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
38 |
def get_wav2vecembeddings_from_audiofile(wav_file):
|
39 |
print("the file is", wav_file)
|
40 |
speech, sample_rate = sf.read(wav_file)
|
41 |
-
|
|
|
|
|
|
|
42 |
print("input values", input_values)
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
print("getting wav2vec2 embeddings")
|
58 |
-
print(last_hidden_state)
|
59 |
-
torch.save(last_hidden_state, wav_file + '.wav2vec2.pt')
|
60 |
|
61 |
|
62 |
|
|
|
6 |
import torch
|
7 |
import numpy as np
|
8 |
import glob
|
9 |
+
import librosa
|
10 |
import numpy
|
11 |
import os.path
|
12 |
|
13 |
processor = AutoTokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
|
|
14 |
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
15 |
+
new_sample_rate = 16000
|
16 |
# Dementia path
|
17 |
# /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
|
18 |
# cookie dementia /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
|
|
|
35 |
if not os.path.isfile(wav_file + ".wav2vec2.pt"):
|
36 |
get_wav2vecembeddings_from_audiofile(wav_file)
|
37 |
|
38 |
+
|
39 |
+
def change_sample_rate(y, sample_rate, new_sample_rate):
|
40 |
+
value = librosa.resample(y, sample_rate, new_sample_rate)
|
41 |
+
return value
|
42 |
+
|
43 |
def get_wav2vecembeddings_from_audiofile(wav_file):
|
44 |
print("the file is", wav_file)
|
45 |
speech, sample_rate = sf.read(wav_file)
|
46 |
+
# change sample rate to 16 000 hertz
|
47 |
+
resampled = change_sample_rate(speech, sample_rate, new_sample_rate)
|
48 |
+
print("the speech is", speech)
|
49 |
+
input_values = processor(resampled, return_tensors="pt", padding=True, sampling_rate=new_sample_rate) # there is no truncation param anymore
|
50 |
print("input values", input_values)
|
51 |
+
# import pdb
|
52 |
+
# pdb.set_trace()
|
53 |
+
|
54 |
+
with torch.no_grad():
|
55 |
+
encoded_states = model(
|
56 |
+
**input_values,
|
57 |
+
# attention_mask=input_values["attention_mask"],
|
58 |
+
output_hidden_states=True
|
59 |
+
)
|
60 |
+
|
61 |
+
last_hidden_state = encoded_states.hidden_states[-1] # The last hidden-state is the first element of the output tuple
|
62 |
+
print("getting wav2vec2 embeddings")
|
63 |
+
print(last_hidden_state)
|
64 |
+
torch.save(last_hidden_state, wav_file + '.wav2vec2.pt')
|
|
|
|
|
|
|
65 |
|
66 |
|
67 |
|
readme.MD
CHANGED
@@ -6,7 +6,7 @@ train
|
|
6 |
# Important readmes
|
7 |
https://github.com/huggingface/transformers/tree/f42a0abf4bd765ad08e14b347a3acbe9fade31b9/examples/research_projects/jax-projects/wav2vec2
|
8 |
|
9 |
-
# path to
|
10 |
# cookie control
|
11 |
data/media.talkbank.org/dementia/English/Pitt/Control/cookie
|
12 |
|
|
|
6 |
# Important readmes
|
7 |
https://github.com/huggingface/transformers/tree/f42a0abf4bd765ad08e14b347a3acbe9fade31b9/examples/research_projects/jax-projects/wav2vec2
|
8 |
|
9 |
+
# path to file
|
10 |
# cookie control
|
11 |
data/media.talkbank.org/dementia/English/Pitt/Control/cookie
|
12 |
|