SLPL
/

SaraSadeghi commited on
Commit
4a36ad9
1 Parent(s): feb891f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +44 -2
README.md CHANGED
@@ -76,7 +76,49 @@ print(prediction[0])
76
  ```
77
 
78
  ## Evaluation
79
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  For the evaluation use the code below:
81
  ```python
82
  ?
@@ -86,7 +128,7 @@ For the evaluation use the code below:
86
 
87
  | clean | other |
88
  |---|---|
89
- | 3.4 | 8.6 |
90
 
91
 
92
  ## Citation
 
76
  ```
77
 
78
  ## Evaluation
79
+ pip install datasets
80
+ pip install transformers
81
+ import torch
82
+ import torchaudio
83
+ import librosa
84
+ from datasets import load_dataset,load_metric
85
+ import numpy as np
86
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
87
+ from transformers import Wav2Vec2ProcessorWithLM
88
+
89
+ model = Wav2Vec2ForCTC.from_pretrained("SLPL/Sharif-wav2vec2")
90
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("SLPL/Sharif-wav2vec2")
91
+
92
+ def speech_file_to_array_fn(batch):
93
+ speech_array, sampling_rate = torchaudio.load(batch["path"])
94
+ speech_array = speech_array.squeeze().numpy()
95
+ speech_array = librosa.resample(np.asarray(speech_array), sampling_rate, processor.feature_extractor.sampling_rate)
96
+ batch["speech"] = speech_array
97
+ return batch
98
+
99
+ def predict(batch):
100
+ features = processor(
101
+ batch["speech"],
102
+ sampling_rate=processor.feature_extractor.sampling_rate,
103
+ return_tensors="pt",
104
+ padding=True
105
+ )
106
+ input_values = features.input_values
107
+ attention_mask = features.attention_mask
108
+
109
+ with torch.no_grad():
110
+ logits = model(input_values, attention_mask=attention_mask).logits #when we are trying to load model with LM we have to use logits instead of argmax(logits)
111
+ batch["prediction"] = processor.batch_decode(logits.numpy()).text
112
+ return batch
113
+
114
+ dataset = load_dataset("csv", data_files={"test":"path/to/your.csv"}, delimiter=",")["test"]
115
+ dataset = dataset.map(speech_file_to_array_fn)
116
+
117
+ result = dataset.map(predict, batched=True, batch_size=4)
118
+ wer = load_metric("wer")
119
+ cer = load_metric("cer")
120
+ print("WER: {:.2f}".format(100 * wer.compute(predictions=result["prediction"], references=result["reference"])))
121
+ print("CER: {:.2f}".format(100 * cer.compute(predictions=result["prediction"], references=result["reference"])))
122
  For the evaluation use the code below:
123
  ```python
124
  ?
 
128
 
129
  | clean | other |
130
  |---|---|
131
+ | 6.0 | 16.4 |
132
 
133
 
134
  ## Citation