Upload copy_of_dl_project.py
Browse files- copy_of_dl_project.py +442 -0
copy_of_dl_project.py
ADDED
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""Copy_of_dl_project.ipynb
|
3 |
+
|
4 |
+
Automatically generated by Colaboratory.
|
5 |
+
|
6 |
+
Original file is located at
|
7 |
+
https://colab.research.google.com/drive/1SAMY64pTPqfF7T0Slzr5K_Xeu7l43Kjj
|
8 |
+
|
9 |
+
## IGNORE EVERYTHING ABOVE THIS. ITS TOO SLOW AND PAINFUL.
|
10 |
+
|
11 |
+
## ONLY RUN CELLS BELOW THIS.
|
12 |
+
"""
|
13 |
+
|
14 |
+
# using an alternate download method
|
15 |
+
|
16 |
+
! pip install datasets
|
17 |
+
|
18 |
+
from datasets import load_dataset, Audio
|
19 |
+
|
20 |
+
# this is annanay's private token from huggingface, please do not leak/share this 🙏
|
21 |
+
access_token = "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"
|
22 |
+
gs = load_dataset("speechcolab/gigaspeech", "xs", "train[:2]", use_auth_token=access_token)
|
23 |
+
gs['train'] = gs['train'].cast_column("audio", Audio(sampling_rate=16_000))
|
24 |
+
|
25 |
+
# see structure
|
26 |
+
print(gs)
|
27 |
+
|
28 |
+
# load audio sample on the fly
|
29 |
+
audio_input = gs["train"][0]["audio"] # first decoded audio sample
|
30 |
+
transcription = gs["train"][0]["text"] # first transcription
|
31 |
+
|
32 |
+
print(len(gs["train"]), len(gs["test"]), len(gs["validation"]))
|
33 |
+
|
34 |
+
"""## phew, that took 25 mins to run!! we can't keep doing this everytime we need to run an experiment
|
35 |
+
### (^ temporarily switched to smaller split but still slow)
|
36 |
+
|
37 |
+
## anyway, for now lets investigate the dataset
|
38 |
+
|
39 |
+
"""
|
40 |
+
|
41 |
+
print("# training datapoints", len(gs["train"]), "# test datapoints", len(gs["test"]))
|
42 |
+
|
43 |
+
categories = ["People and Blogs", "Business", "Nonprofits and Activism", "Crime", "History", "Pets and Animals", "News and Politics", "Travel and Events", "Kids and Family", "Leisure", "N/A", "Comedy", "News and Politics", "Sports", "Arts", "Science and Technology", "Autos and Vehicles", "Science and Technology", "People and Blogs", "Music", "Society and Culture", "Education", "Howto and Style", "Film and Animation", "Gaming", "Entertainment", "Travel and Events", "Health and Fitness", "audiobook"]
|
44 |
+
print("number of categories", len(categories))
|
45 |
+
|
46 |
+
print("check if categories are zero indexed")
|
47 |
+
for i in range(0, 30):
|
48 |
+
print(gs["train"][i]["category"], categories[gs["train"][i]["category"]-1])
|
49 |
+
|
50 |
+
print("i see a zero in there somewhere, so yes, they are.")
|
51 |
+
|
52 |
+
print("Category", gs["train"][0]["category"], "is", categories[gs["train"][0]["category"]])
|
53 |
+
|
54 |
+
"""Ignore the next cell"""
|
55 |
+
|
56 |
+
# trying to play an audio sample
|
57 |
+
from IPython.display import Audio, display
|
58 |
+
|
59 |
+
display(Audio(gs["train"][0]["audio"]["path"], autoplay=True))
|
60 |
+
|
61 |
+
"""# Keyur's implementation
|
62 |
+
|
63 |
+
## Ok now let's load and use the wav2vec model
|
64 |
+
"""
|
65 |
+
|
66 |
+
# see https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2CTCTokenizer.decode.example
|
67 |
+
|
68 |
+
! pip install transformers torch
|
69 |
+
|
70 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
|
71 |
+
import torch
|
72 |
+
|
73 |
+
# import model, feature extractor, tokenizer
|
74 |
+
model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
|
75 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
76 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
77 |
+
|
78 |
+
# load first 10 training samples
|
79 |
+
first10 = [a['array'] for a in gs['train'][0:10]['audio']]#gs['train'][0]['audio']['array']#
|
80 |
+
|
81 |
+
# forward sample through model to get greedily predicted transcription ids
|
82 |
+
input_values = feature_extractor(
|
83 |
+
first10,
|
84 |
+
return_tensors="pt",
|
85 |
+
sampling_rate=16_000,
|
86 |
+
padding=True
|
87 |
+
).input_values.to("cuda")
|
88 |
+
logits = model(input_values).logits
|
89 |
+
|
90 |
+
pred_ids = torch.argmax(logits, axis=-1)
|
91 |
+
|
92 |
+
# Output word_offsets (i.e. recognized words and their timestamps)
|
93 |
+
# If we wanted recognized characters and their timestamps, replace with analogous output_char_offsets
|
94 |
+
outputs = tokenizer.batch_decode(pred_ids, output_word_offsets=True)
|
95 |
+
|
96 |
+
# compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
|
97 |
+
time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
|
98 |
+
|
99 |
+
for i, (text, row) in enumerate(zip(outputs.text, outputs.word_offsets)):
|
100 |
+
print(f"#{i}: {text}")
|
101 |
+
|
102 |
+
for word_offset in row:
|
103 |
+
print({
|
104 |
+
"word": word_offset["word"],
|
105 |
+
"start_time": round(word_offset["start_offset"] * time_offset, 2),
|
106 |
+
"end_time": round(word_offset["end_offset"] * time_offset, 2),
|
107 |
+
})
|
108 |
+
|
109 |
+
"""## We need to deal with this error message from loading the weights though. Are there fully-trained weights elsewhere?
|
110 |
+
|
111 |
+
```
|
112 |
+
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
|
113 |
+
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
114 |
+
```
|
115 |
+
|
116 |
+
## (otherwise maybe it doesn't matter since we'll be finetuning anyway)
|
117 |
+
"""
|
118 |
+
|
119 |
+
# sample package for computing WER, there may be a faster one for batches though
|
120 |
+
! pip install torchmetrics
|
121 |
+
from torchmetrics import WordErrorRate
|
122 |
+
wer = WordErrorRate()
|
123 |
+
|
124 |
+
# remove punctuation / special marks
|
125 |
+
trues = [
|
126 |
+
" ".join(
|
127 |
+
token
|
128 |
+
for token in transcript.split()
|
129 |
+
if not (token.startswith("<") and token.endswith(">"))
|
130 |
+
).upper()
|
131 |
+
for transcript in gs['train']['text']
|
132 |
+
]
|
133 |
+
print(trues)
|
134 |
+
predicteds = [t.upper() for t in outputs.text]
|
135 |
+
print(predicteds)
|
136 |
+
print(wer(predicteds, trues).item())
|
137 |
+
|
138 |
+
"""## This WER is computed over the entire dataset. Do we want individual WER of each row instead?
|
139 |
+
|
140 |
+
# Split Datasets into categories
|
141 |
+
"""
|
142 |
+
|
143 |
+
len(gs["train"])
|
144 |
+
|
145 |
+
"""
|
146 |
+
### I'd recommend not to split the datasets (the cell below) again and again. Takes too long!"""
|
147 |
+
|
148 |
+
# split the dataset
|
149 |
+
for i in range(len(categories)):
|
150 |
+
print(i, categories[i])
|
151 |
+
|
152 |
+
# datasets
|
153 |
+
audiobooks = []
|
154 |
+
comedy = []
|
155 |
+
people_blogs = []
|
156 |
+
education = []
|
157 |
+
gaming = []
|
158 |
+
|
159 |
+
|
160 |
+
for i in range(len(gs["train"])):
|
161 |
+
if gs['train'][i]['category'] == 28:
|
162 |
+
audiobooks.append(gs['train'][i])
|
163 |
+
elif gs['train'][i]['category'] == 11:
|
164 |
+
comedy.append(gs['train'][i])
|
165 |
+
elif (gs['train'][i]['category'] == 0) or (gs['train'][i]['category'] == 18):
|
166 |
+
people_blogs.append(gs['train'][i])
|
167 |
+
elif gs['train'][i]['category'] == 21:
|
168 |
+
education.append(gs['train'][i])
|
169 |
+
elif gs['train'][i]['category'] == 24:
|
170 |
+
gaming.append(gs['train'][i])
|
171 |
+
|
172 |
+
# could have written code for the split, but i did it manually.
|
173 |
+
|
174 |
+
# train dataset (70% of the len of datasets)
|
175 |
+
audiobooks_train = audiobooks[:1648]
|
176 |
+
comedy_train = comedy[:89]
|
177 |
+
people_blogs_train = people_blogs[:244]
|
178 |
+
education_train = education[:978]
|
179 |
+
gaming_train = gaming[:713]
|
180 |
+
|
181 |
+
# test dataset (30% of the len of datasets)
|
182 |
+
audiobooks_test = audiobooks[1648:]
|
183 |
+
comedy_test = comedy[89:]
|
184 |
+
people_blogs_test = people_blogs[244:]
|
185 |
+
education_test = education[978:]
|
186 |
+
gaming_test = gaming[713:]
|
187 |
+
|
188 |
+
print(len(audiobooks_train), len(comedy_train), len(people_blogs_train), len(education_train), len(gaming_train))
|
189 |
+
|
190 |
+
print(len(audiobooks_test), len(comedy_test), len(people_blogs_test), len(education_test), len(gaming_test))
|
191 |
+
|
192 |
+
audiobooks_train[:2]
|
193 |
+
|
194 |
+
from datasets import Dataset, DatasetDict
|
195 |
+
import pandas as pd
|
196 |
+
|
197 |
+
audiobooks_train_df = pd.DataFrame(audiobooks_train)
|
198 |
+
audiobooks_test_df = pd.DataFrame(audiobooks_test)
|
199 |
+
|
200 |
+
audiobooks_train_data = Dataset.from_pandas(audiobooks_train_df)
|
201 |
+
audiobooks_test_data = Dataset.from_pandas(audiobooks_test_df)
|
202 |
+
|
203 |
+
# create a DatasetDict with two train and test datasets
|
204 |
+
DataDict_audiobooks = DatasetDict({
|
205 |
+
"train": audiobooks_train_data,
|
206 |
+
"test": audiobooks_test_data
|
207 |
+
})
|
208 |
+
|
209 |
+
DataDict_audiobooks
|
210 |
+
|
211 |
+
DataDict_audiobooks = DataDict_audiobooks.remove_columns(['segment_id', 'speaker', 'begin_time', 'end_time', 'audio_id', 'title', 'url', 'source', 'original_full_path'])
|
212 |
+
|
213 |
+
DataDict_audiobooks["train"][0].keys()
|
214 |
+
|
215 |
+
DataDict_audiobooks["train"][0]["audio"]
|
216 |
+
|
217 |
+
"""# Text Preprocessing"""
|
218 |
+
|
219 |
+
import re
|
220 |
+
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\<.*?>]'
|
221 |
+
|
222 |
+
def remove_special_characters(batch):
|
223 |
+
batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
|
224 |
+
return batch
|
225 |
+
|
226 |
+
DataDict_audiobooks = DataDict_audiobooks.map(remove_special_characters)
|
227 |
+
|
228 |
+
import json
|
229 |
+
|
230 |
+
def extract_chars(batch):
|
231 |
+
all_text = " ".join(batch["text"])
|
232 |
+
vocab = list(set(all_text))
|
233 |
+
return {"vocab": [vocab], "all_text": [all_text]}
|
234 |
+
|
235 |
+
vocabs = DataDict_audiobooks.map(extract_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=DataDict_audiobooks.column_names["train"])
|
236 |
+
|
237 |
+
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))
|
238 |
+
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
|
239 |
+
vocab_dict["|"] = vocab_dict[" "]
|
240 |
+
del vocab_dict[" "]
|
241 |
+
|
242 |
+
vocab_dict["[UNK]"] = len(vocab_dict) # add "unknown" token
|
243 |
+
vocab_dict["[PAD]"] = len(vocab_dict) # add a padding token that corresponds to CTC's "blank token"
|
244 |
+
|
245 |
+
with open('vocab.json', 'w') as vocab_file:
|
246 |
+
json.dump(vocab_dict, vocab_file)
|
247 |
+
|
248 |
+
vocab_dict
|
249 |
+
|
250 |
+
DataDict_audiobooks
|
251 |
+
|
252 |
+
import random
|
253 |
+
import numpy as np
|
254 |
+
rand_int = random.randint(0, len(DataDict_audiobooks["train"]))
|
255 |
+
|
256 |
+
print("Target text:", DataDict_audiobooks["train"][rand_int]["text"])
|
257 |
+
print("Input array shape:", np.asarray(DataDict_audiobooks["train"][rand_int]["audio"]["array"]).shape)
|
258 |
+
print("Sampling rate:", DataDict_audiobooks["train"][rand_int]["audio"]["sampling_rate"])
|
259 |
+
|
260 |
+
"""# Time to fine-tune"""
|
261 |
+
|
262 |
+
!pip install transformers
|
263 |
+
|
264 |
+
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
265 |
+
|
266 |
+
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
|
267 |
+
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
|
268 |
+
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
269 |
+
|
270 |
+
DataDict_audiobooks
|
271 |
+
|
272 |
+
DataDict_audiobooks["train"][0]['audio']['sampling_rate']
|
273 |
+
|
274 |
+
print(type(DataDict_audiobooks["train"][0]["audio"]))
|
275 |
+
|
276 |
+
def modify_sample(sample):
|
277 |
+
sample['audio'] = sample['audio']['array']
|
278 |
+
return sample
|
279 |
+
|
280 |
+
temp = DataDict_audiobooks.map(modify_sample)
|
281 |
+
|
282 |
+
temp['train'][0]['audio']
|
283 |
+
|
284 |
+
def prepare_dataset(batch):
|
285 |
+
# print(type(batch), batch.keys())
|
286 |
+
x = batch["audio"]
|
287 |
+
# print("SIZE: ", len(x))
|
288 |
+
|
289 |
+
# y = x["array"]
|
290 |
+
z = 16000
|
291 |
+
|
292 |
+
batch["input_values"] = processor(x, sampling_rate=z).input_values
|
293 |
+
|
294 |
+
with processor.as_target_processor():
|
295 |
+
batch["labels"] = processor(batch["text"]).input_ids
|
296 |
+
return batch
|
297 |
+
|
298 |
+
processor.tokenizer.is_fast
|
299 |
+
|
300 |
+
DataDict_audiobooks_prepared = temp.map(prepare_dataset, batch_size=8, batched=True)
|
301 |
+
|
302 |
+
temp
|
303 |
+
|
304 |
+
# DataDict_audiobooks_prepared = DataDict_audiobooks.map(prepare_dataset, batch_size=8, num_proc=4, batched=True)
|
305 |
+
|
306 |
+
"""# Training and Evaluation"""
|
307 |
+
|
308 |
+
import torch
|
309 |
+
|
310 |
+
from dataclasses import dataclass, field
|
311 |
+
from typing import Any, Dict, List, Optional, Union
|
312 |
+
|
313 |
+
@dataclass
|
314 |
+
class DataCollatorCTCWithPadding:
|
315 |
+
processor: Wav2Vec2Processor
|
316 |
+
padding: Union[bool, str] = True
|
317 |
+
max_length: Optional[int] = None
|
318 |
+
max_length_labels: Optional[int] = None
|
319 |
+
pad_to_multiple_of: Optional[int] = None
|
320 |
+
pad_to_multiple_of_labels: Optional[int] = None
|
321 |
+
|
322 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
323 |
+
# split inputs and labels since they have to be of different lenghts and need
|
324 |
+
# different padding methods
|
325 |
+
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
326 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
327 |
+
|
328 |
+
batch = self.processor.pad(
|
329 |
+
input_features,
|
330 |
+
padding=self.padding,
|
331 |
+
max_length=self.max_length,
|
332 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
333 |
+
return_tensors="pt",
|
334 |
+
)
|
335 |
+
with self.processor.as_target_processor():
|
336 |
+
labels_batch = self.processor.pad(
|
337 |
+
label_features,
|
338 |
+
padding=self.padding,
|
339 |
+
max_length=self.max_length_labels,
|
340 |
+
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
341 |
+
return_tensors="pt",
|
342 |
+
)
|
343 |
+
|
344 |
+
# replace padding with -100 to ignore loss correctly
|
345 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
346 |
+
|
347 |
+
batch["labels"] = labels
|
348 |
+
|
349 |
+
return batch
|
350 |
+
|
351 |
+
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
|
352 |
+
|
353 |
+
!pip install jiwer
|
354 |
+
|
355 |
+
from datasets import load_dataset, load_metric
|
356 |
+
|
357 |
+
wer_metric = load_metric("wer")
|
358 |
+
|
359 |
+
def compute_metrics(pred):
|
360 |
+
pred_logits = pred.predictions
|
361 |
+
pred_ids = np.argmax(pred_logits, axis=-1)
|
362 |
+
|
363 |
+
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
|
364 |
+
|
365 |
+
pred_str = processor.batch_decode(pred_ids)
|
366 |
+
# we do not want to group tokens when computing the metrics
|
367 |
+
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
368 |
+
|
369 |
+
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
370 |
+
|
371 |
+
return {"wer": wer}
|
372 |
+
|
373 |
+
from transformers import Wav2Vec2ForCTC
|
374 |
+
|
375 |
+
model = Wav2Vec2ForCTC.from_pretrained(
|
376 |
+
"facebook/wav2vec2-base",
|
377 |
+
gradient_checkpointing=True,
|
378 |
+
ctc_loss_reduction="mean",
|
379 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
380 |
+
)
|
381 |
+
|
382 |
+
model.freeze_feature_extractor()
|
383 |
+
|
384 |
+
from transformers import TrainingArguments
|
385 |
+
|
386 |
+
training_args = TrainingArguments(
|
387 |
+
output_dir="./wav2vec2-base-speechcolab-demo",
|
388 |
+
group_by_length=True,
|
389 |
+
per_device_train_batch_size=32,
|
390 |
+
evaluation_strategy="steps",
|
391 |
+
num_train_epochs=30,
|
392 |
+
fp16=True,
|
393 |
+
save_steps=500,
|
394 |
+
eval_steps=500,
|
395 |
+
logging_steps=100,
|
396 |
+
learning_rate=1e-4,
|
397 |
+
weight_decay=0.005,
|
398 |
+
warmup_steps=1000,
|
399 |
+
save_total_limit=2,
|
400 |
+
)
|
401 |
+
|
402 |
+
from transformers import Trainer
|
403 |
+
|
404 |
+
trainer = Trainer(
|
405 |
+
model=model,
|
406 |
+
data_collator=data_collator,
|
407 |
+
args=training_args,
|
408 |
+
compute_metrics=compute_metrics,
|
409 |
+
train_dataset=DataDict_audiobooks_prepared["train"],
|
410 |
+
eval_dataset=DataDict_audiobooks_prepared["test"],
|
411 |
+
tokenizer=processor.feature_extractor,
|
412 |
+
)
|
413 |
+
|
414 |
+
trainer.train()
|
415 |
+
|
416 |
+
"""# Evaluate"""
|
417 |
+
|
418 |
+
processor = Wav2Vec2Processor.from_pretrained("/content/wav2vec2-base-speechcolab-demo")
|
419 |
+
model = Wav2Vec2ForCTC.from_pretrained("/content/wav2vec2-base-speechcolab-demo")
|
420 |
+
|
421 |
+
def map_to_result(batch):
|
422 |
+
model.to("cuda")
|
423 |
+
input_values = processor(
|
424 |
+
batch["array"]["audio"],
|
425 |
+
sampling_rate=batch["array"]["sampling_rate"],
|
426 |
+
return_tensors="pt"
|
427 |
+
).input_values.to("cuda")
|
428 |
+
|
429 |
+
with torch.no_grad():
|
430 |
+
logits = model(input_values).logits
|
431 |
+
|
432 |
+
pred_ids = torch.argmax(logits, dim=-1)
|
433 |
+
batch["pred_str"] = processor.batch_decode(pred_ids)[0]
|
434 |
+
|
435 |
+
return batch
|
436 |
+
|
437 |
+
results = DataDict_audiobooks_prepared["test"].map(map_to_result)
|
438 |
+
|
439 |
+
# Preprocess text?
|
440 |
+
# Test dataset is empty
|
441 |
+
# Something wrong in the input I'm sending and the model is expecting. Trying to fix that
|
442 |
+
|