birgermoell committed
Commit f9c55bd
Parent: 8093841

Updated to input wav file directly

Files changed (1): feature_extractor.py (+10 -3)
feature_extractor.py CHANGED
@@ -40,24 +40,31 @@ def change_sample_rate(y, sample_rate, new_sample_rate):
     value = librosa.resample(y, sample_rate, new_sample_rate)
     return value
 
+def stereo_to_mono(audio_input):
+    X = audio_input.mean(axis=1, keepdims=True)
+    X = np.squeeze(X)
+    return X
+
 def get_wav2vecembeddings_from_audiofile(wav_file):
     print("the file is", wav_file)
     speech, sample_rate = sf.read(wav_file)
+
+    if len(speech.shape) > 1:
+        speech = stereo_to_mono(speech)
     # change sample rate to 16 000 hertz
     resampled = change_sample_rate(speech, sample_rate, new_sample_rate)
     print("the speech is", speech)
-    input_values = processor(resampled, return_tensors="pt", padding=True, sampling_rate=new_sample_rate)  # there is no truncation param anymore
+    input_values = processor(wav_file, return_tensors="pt", padding=True, sampling_rate=new_sample_rate)  # there is no truncation param anymore
     print("input values", input_values)
     # import pdb
     # pdb.set_trace()
 
     with torch.no_grad():
         encoded_states = model(
-            **input_values,
+            input_values=input_values["input_ids"],
             # attention_mask=input_values["attention_mask"],
             output_hidden_states=True
         )
-
     last_hidden_state = encoded_states.hidden_states[-1]  # The last hidden-state is the first element of the output tuple
     print("getting wav2vec2 embeddings")
     print(last_hidden_state)
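
For context, a minimal self-contained sketch of the same embedding pipeline. It assumes the module-level processor and model are a transformers Wav2Vec2Processor and Wav2Vec2Model (the checkpoint name below is a placeholder, not taken from this repo). Two API details worth checking against this commit: the transformers audio processor expects the waveform array rather than a file path and returns its tensors under "input_values" (not "input_ids"), and librosa 0.10+ makes the resample rates keyword-only, so the positional call in change_sample_rate only works on older librosa.

import librosa
import numpy as np
import soundfile as sf
import torch
from transformers import Wav2Vec2Model, Wav2Vec2Processor

new_sample_rate = 16000  # wav2vec 2.0 models expect 16 kHz audio

# Placeholder checkpoint; substitute whichever wav2vec2 model this repo loads.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

def get_wav2vec_embeddings(wav_file):
    speech, sample_rate = sf.read(wav_file)
    if speech.ndim > 1:
        speech = np.squeeze(speech.mean(axis=1, keepdims=True))  # stereo -> mono
    # keyword form works on both old and new librosa versions
    speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=new_sample_rate)
    # the processor takes the waveform (not the path) and returns "input_values"
    inputs = processor(speech, sampling_rate=new_sample_rate, return_tensors="pt", padding=True)
    with torch.no_grad():
        encoded_states = model(inputs.input_values, output_hidden_states=True)
    return encoded_states.hidden_states[-1]  # shape: (batch, frames, hidden_size)

The returned tensor has one vector per audio frame; mean-pooling over the frame axis is a common way to get a single utterance-level embedding.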