SaraSadeghi
committed on
Commit
•
4a36ad9
1
Parent(s):
feb891f
Update README.md
Browse files
README.md
CHANGED
@@ -76,7 +76,49 @@ print(prediction[0])
|
|
76 |
```
|
77 |
|
78 |
## Evaluation
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
For the evaluation use the code below:
|
81 |
```python
|
82 |
?
|
@@ -86,7 +128,7 @@ For the evaluation use the code below:
|
|
86 |
|
87 |
| clean | other |
|
88 |
|---|---|
|
89 |
-
|
|
90 |
|
91 |
|
92 |
## Citation
|
|
|
76 |
```
|
77 |
|
78 |
## Evaluation
|
79 |
+
pip install datasets
|
80 |
+
pip install transformers
|
81 |
+
import torch
|
82 |
+
import torchaudio
|
83 |
+
import librosa
|
84 |
+
from datasets import load_dataset,load_metric
|
85 |
+
import numpy as np
|
86 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
87 |
+
from transformers import Wav2Vec2ProcessorWithLM
|
88 |
+
|
89 |
+
model = Wav2Vec2ForCTC.from_pretrained("SLPL/Sharif-wav2vec2")
|
90 |
+
processor = Wav2Vec2ProcessorWithLM.from_pretrained("SLPL/Sharif-wav2vec2")
|
91 |
+
|
92 |
+
def speech_file_to_array_fn(batch):
|
93 |
+
speech_array, sampling_rate = torchaudio.load(batch["path"])
|
94 |
+
speech_array = speech_array.squeeze().numpy()
|
95 |
+
speech_array = librosa.resample(np.asarray(speech_array), sampling_rate, processor.feature_extractor.sampling_rate)
|
96 |
+
batch["speech"] = speech_array
|
97 |
+
return batch
|
98 |
+
|
99 |
+
def predict(batch):
|
100 |
+
features = processor(
|
101 |
+
batch["speech"],
|
102 |
+
sampling_rate=processor.feature_extractor.sampling_rate,
|
103 |
+
return_tensors="pt",
|
104 |
+
padding=True
|
105 |
+
)
|
106 |
+
input_values = features.input_values
|
107 |
+
attention_mask = features.attention_mask
|
108 |
+
|
109 |
+
with torch.no_grad():
|
110 |
+
logits = model(input_values, attention_mask=attention_mask).logits # With the LM-backed processor, pass the raw logits to batch_decode instead of argmax(logits).
|
111 |
+
batch["prediction"] = processor.batch_decode(logits.numpy()).text
|
112 |
+
return batch
|
113 |
+
|
114 |
+
dataset = load_dataset("csv", data_files={"test":"path/to/your.csv"}, delimiter=",")["test"]
|
115 |
+
dataset = dataset.map(speech_file_to_array_fn)
|
116 |
+
|
117 |
+
result = dataset.map(predict, batched=True, batch_size=4)
|
118 |
+
wer = load_metric("wer")
|
119 |
+
cer = load_metric("cer")
|
120 |
+
print("WER: {:.2f}".format(100 * wer.compute(predictions=result["prediction"], references=result["reference"])))
|
121 |
+
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["prediction"], references=result["reference"])))
|
122 |
For the evaluation use the code below:
|
123 |
```python
|
124 |
?
|
|
|
128 |
|
129 |
| clean | other |
|
130 |
|---|---|
|
131 |
+
| 6.0 | 16.4 |
|
132 |
|
133 |
|
134 |
## Citation
|