updated source code

- src/readme.md  +146 -37
- src/run_base.sh  +44 -0
- src/{run.sh → run_small.sh}  +5 -3
- src/run_speech_recognition_seq2seq_streaming.py  +91 -49
- src/{run_debug.sh → run_tiny_debug.sh}  +0 -0
src/readme.md
CHANGED
@@ -18,11 +18,24 @@ The code in this repository is a modified version of code from
 ```
 
 ## Fine-tuning todos:
+* logs are printed only right before the evaluation:<br>
+  ```
+  --logging_steps="50"
+  --eval_steps="1000"
+  ```
+* on the next run:
+  * download the whole dataset before the launch.
+    this will probably save some time for data processing,
+    and allow loading and preparing data in parallel
+  * can also decrease the eval batch size. currently it's probably causing the GPU to wait for the CPU to prepare the next batch
 * perform evaluation of fine-tuned model on CommonVoice test set
+* add the [Whisper fine-tuning Event repo](https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event)
+  to remotes and merge updates from this original event repo
 * Learning rate:
   * max learning rate is not the same as LR passed as a parameter to training script. it is actually lower.
   * when resuming training, LR scheduling behaves incorrectly
 * check exact sizes of train, eval, test sets of CommonVoice 11
+* fill the TODOs in the Notes section with answers and discussions from Discord
 
 ## Resuming training from exising checkpoint
 When resuming training from existing checkpoint:
@@ -55,6 +68,138 @@ When resuming training from existing checkpoint:
   How is it overwritten when resuming training from existing checkpoint?
 * does `ShuffleCallback` work with StreamingDataset? it reshuffles data `on_epoch_begin()`,
   but does StreamingDataset have any epochs?
+* does streaming mode support parallel data loading and processing?<br>
+  when using non-streaming mode we can use `dataset.map(..., num_proc=<num_proc>)`
+
+
+## Notes:
+* using CommonVoice 11 dataset in a streaming way.<br>
+  use `streaming=True` for train & validation & test.<br>
+  as an alternative, we can use `streaming=False` for validation & test sets to save time on data processing.
+  but the sizes of the validation and test sets are unknown (need to check).
+  it's likely they are going to be large - thus pre-download of these sets might not reduce
+  overall fine-tuning time compared to streaming mode.
+* size of train set is ~370'000 audiofiles. if using `batch_size=64`, then
+  1 epoch will have ~5782 steps.<br>
+  Because of `--eval_steps="1000"`, will use `--max_steps="6000"` instead of `--max_steps="5800"`
+  to have evaluation metrics computed at the end of training.
+* if using Google Colab, need to execute `sudo chmod -R 777 .git` inside the hf repo
+  to set the right permissions to be able to push trained models to the HuggingFace Hub
+* Log tracking in Jupyter (not working) and in bash (works as expected with `tee`)
+* Loggers in `run_speech.....py` do not control `transformers` and `datasets` loggers.
+  can't redirect their outputs using handlers. it's better and easier to redirect output in bash
+* Need to set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible
+* Default Linear scheduler is used
+* Default Adam optimizer is used
+
+### Logs not printed when expected
+* Train logs are printed only before the start of a validation.
+  During training they are not printed to stdout.
+  All worked fine in Colab.
+* No progressbar for validation (at least when using streaming and an iterable dataset).
+  possible reason is that when using streaming, the dataset length is unknown.
+* Evaluation metrics get printed to stdout only before the next validation call.
+  All worked fine in Colab.
+* Possible reason: usage of `... | tee file.log`. But it's unlikely
+
+### Text normalization
+* Whisper's BasicTextNormalizer splits words containing an apostrophe:
+  ```python
+  > from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+  > normalizer = BasicTextNormalizer()
+  > normalizer("раз'яднаць")
+  'раз яднаць'
+  ```
+* That's why `BelarusianTextNormalizer` (an edited version of `BasicTextNormalizer`) was added to the training script:
+  ```python
+  > from run_speech_recognition_seq2seq_streaming import BelarusianTextNormalizer
+  > normalizer_be = BelarusianTextNormalizer()
+  > normalizer_be("раз'яднаць")
+  "раз'яднаць"
+  ```
+
+### Different batch sizes for train and evaluation:
+* Theoretically you can use a larger batch size for evaluation vs training!
+* Training: we do a forward pass, storing all the activations, and then a backwards pass, storing all the gradients
+* Inference (evaluation): we only do a forward pass, and don't store any activations
+* So the memory required for evaluation is much lower than it is for training
+  (we're only doing the forward pass and not storing any values)
+* In my experience, altering the eval batch size has little effect on eval speed ->
+  I set it to a lower value as this tends to give a more responsive progress bar
+  when evaluating in non-streaming mode (the bar updates faster and more frequently)
+
+### Slow inference. Long evaluation compared to training:
+* Slower inference is an inherent limitation of the sequence-to-sequence architecture.
+  The auto-regressive decoding means that you have to do as many decoder forward passes as tokens generated.
+* This is much slower than CTC, where you do a single encoder forward pass
+* Note that 1 evaluation step **will take much longer** than 1 training step, even with the same batch sizes.
+* With training, we do one forward pass of the encoder, one forward pass of the decoder,
+  one backward pass of the decoder and one backward pass of the encoder (=4 passes total):<br>
+  ```
+  audio -> encoder -> decoder -> labels
+           encoder <- decoder <- loss
+  ```
+* During evaluation we do one forward pass of the encoder, and then auto-regressively generate tokens in the decoder.
+  Here, we do as many forward passes of the decoder as tokens generated.
+  So in total, we do one forward pass of the encoder, and N forward passes of the decoder,
+  where N is the number of tokens generated (can be up to the max length, which is 448...).
+  You can see that for 4 or more generated tokens, evaluation is going to be slower than training:<br>
+  ```
+  audio -> encoder -> decoder -> decoder -> decoder -> ... -> decoder -> end of sentence token
+  ```
+* I've made a bit of a simplification here in saying that one forward pass
+  takes the same amount of time as one backward pass, but for the purpose of illustration
+  this demonstrates why evaluation is much slower than training
+* Essentially it doesn't really matter what you set your eval batch size to, as we're not aggregating any statistics
+  over the eval batch (in contrast, during training we compute a true gradient value based on a given batch).
+* Since we just do a forward pass, we could even run eval with a batch size of 1 and get exactly the same results!
+* Because we don't get much of an improvement with batch sizes beyond around 8, it's set somewhat arbitrarily
+
+### Ways to decrease evaluation time during fine-tuning:
+* reduce `generation_max_length` param:
+  * During training, we can limit the generation max length to a lower number to cut off the generation
+    after fewer tokens (e.g. 40). This will give worse results during training,
+    but we can still infer the evolution of WER performance over training.
+  * For the final eval step, we can bump the generation max length back up to 448.
+  * WER performance varies monotonically with generation max length
+    (WER can only stay equal or improve by increasing generation max length),
+    so we know that our final eval WER will be less than (improved) or equal to the WER during training
+* We can evaluate at less frequent eval_steps: this reduces the number of times we have to perform evaluation
+
+### Decrease inference time more generally
+* PyTorch 2.0 and compiling the model could get you a decent speed-up
+  (https://pytorch.org/blog/Accelerating-Hugging-Face-and-TIMM-models/#hugging-face-models)
+* Downcasting to fp16
+
+### Memory saving and training larger models:
+To save memory (and increase either the model size or batch_size) we can experiment with:
+* using Adafactor instead of Adam.
+  Adam requires two optimiser params per model param, but Adafactor uses only one.
+  > A word of caution: Adafactor is untested for fine-tuning Whisper,
+  so we are unsure how Adafactor performance compares to Adam!
+* using Adam 8bit from the `bitsandbytes` module.
+  need to provide the `optim="adamw_bnb_8bit"` param to `Seq2SeqTrainingArguments`
+* use `deepspeed`. scripts are available in the
+  [Whisper fine-tuning Event repo](https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event)
+* load the model and processor in 8bit mode:
+  ```python
+  from transformers import WhisperForConditionalGeneration, WhisperProcessor
+  model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large", device_map="auto", load_in_8bit=True)
+  processor = WhisperProcessor.from_pretrained("openai/whisper-large", load_in_8bit=True)
+  ```
+  inference loop:
+  ```python
+  for data in dataset:
+      inputs = processor.feature_extractor(data["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_features.half().to(device)
+      forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
+      predicted_ids = model.generate(inputs, forced_decoder_ids=forced_decoder_ids)
+      text = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=False)[0]
+      print(text)
+  ```
+  * 8bit will slow down inference compared to full/half-precision
+  * But the memory saving you get is immense (up to 4x vs full-precision).<br>
+    This is the recommended approach when you're limited on VRAM.<br>
+    If you care about inference speed, stick to full precision
 
 ### Prepended tokens
 * Why are there following lines in Data Collator?
@@ -90,40 +235,4 @@ When resuming training from existing checkpoint:
 
 * We need to tell the model what language the audio corresponds to and what task it's performing during fine-tuning. This way, it learns what audio corresponds to what language, and the difference between transcribing audio vs translating it
 
-
-* using CommonVoice 11 dataset in a streaming way.<br>
-  use `streaming=True` for train & validation & test.<br>
-  as an alternative, we can use `streaming=False` for validation & test sets to save time on data processing.
-  but the size of validation and test sets are unknown (need to check).
-  it's likely they are going to be large - thus pre-download of these sets might not reduce
-  overall fine-tuning time compared to streaming mode.
-* size of train set is ~370'000 audiofiles. if using `batch_size=64`, then
-  1 epoch will have ~5782 steps. <br>
-  Because of `--eval_steps="1000"` will use `--max_steps="6000"` instead of `--max_steps="5800"`
-  to have evaluation metrics computed in the end of training.
-* if using Google Colab, need to execute `sudo chmod -R 777 .git` inside hf repo to
-  to set right permissions to be able to push trained models to HuggingFace Hub
-* Whispers BasicTextNormalizer splits words containing apostrophe:
-  ```python
-  > from transformers.models.whisper.english_normalizer import BasicTextNormalizer
-  > normalizer = BasicTextNormalizer()
-  > normalizer("раз'яднаць")
-  'раз яднаць'
-  ```
-* That's why `BelarusianTextNormalizer` (edited version of `BasicTextNormalizer`) was added to training script:
-  ```python
-  > from run_speech_recognition_seq2seq_streaming import BelarusianTextNormalizer
-  > normalizer_be = BelarusianTextNormalizer()
-  > normalizer_be("раз'яднаць")
-  "раз'яднаць"
-  ```
-* Need to set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible
-* Default Linear scheduler is used
-* Default Adam optimizer is used
-* To save memory (and increase either model or batch_size) can experiment with:
-  * using Adafactor instead of Adam.
-    Adam requires two optimiser params per one model param, but Adafactor uses only one.
-    > A word of caution: Adafactor is untested for fine-tuning Whisper,
-    so we are unsure sure how Adafactor performance compares to Adam!
-  * using Adam 8bit from `bitsandbytes` module.
-    need to provide `optim="adamw_bnb_8bit"` param to `Seq2SeqTrainingArguments`
+
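The streaming setup described in the Notes above maps onto `datasets.load_dataset` roughly as follows. This is a minimal sketch for reference, not part of the commit; the dataset name, config and splits are taken from the run scripts below, everything else is illustrative:

```python
from datasets import load_dataset

# Train split: streamed, so audio is downloaded and decoded lazily.
train_ds = load_dataset(
    "mozilla-foundation/common_voice_11_0", "be",
    split="train", streaming=True, use_auth_token=True,
)

# Validation split: optionally non-streaming (downloaded and cached once),
# which makes parallel preprocessing via dataset.map(..., num_proc=...) possible.
eval_ds = load_dataset(
    "mozilla-foundation/common_voice_11_0", "be",
    split="validation", streaming=False, use_auth_token=True,
)

print(next(iter(train_ds))["sentence"])
```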
src/run_base.sh
ADDED
@@ -0,0 +1,44 @@
+python src/run_speech_recognition_seq2seq_streaming.py \
+    --model_name_or_path="openai/whisper-base" \
+    --dataset_name="mozilla-foundation/common_voice_11_0" \
+    --dataset_config_name="be" \
+    --language="be" \
+    --train_split_name="train" \
+    --eval_split_name="validation" \
+    --model_index_name="Whisper Base Belarusian" \
+    \
+    --max_steps="6000" \
+    --output_dir="./" \
+    --per_device_train_batch_size="64" \
+    --per_device_eval_batch_size="32" \
+    --logging_steps="50" \
+    --logging_first_step \
+    --learning_rate="1e-4" \
+    --warmup_steps="500" \
+    --evaluation_strategy="steps" \
+    --eval_steps="1000" \
+    --save_strategy="steps" \
+    --save_steps="1000" \
+    --gradient_checkpointing \
+    --fp16 \
+    \
+    --shuffle_buffer_size="500" \
+    --generation_max_length="225" \
+    --max_duration_in_seconds="30" \
+    --text_column_name="sentence" \
+    --freeze_feature_encoder="False" \
+    --report_to="tensorboard" \
+    --metric_for_best_model="wer" \
+    --greater_is_better="False" \
+    --load_best_model_at_end \
+    \
+    --do_train \
+    --do_eval \
+    --ignore_data_skip \
+    --predict_with_generate \
+    --do_normalize_eval \
+    --streaming_train="True" \
+    --streaming_eval="False" \
+    --use_auth_token \
+    --push_to_hub \
+    --hub_model_id="ales/whisper-base-belarusian"
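The `--max_steps="6000"` above follows the arithmetic from the Notes: with roughly 370'000 training files and `per_device_train_batch_size=64`, one epoch is about 5782 steps, and rounding up to a multiple of `--eval_steps="1000"` guarantees that evaluation also runs at the very end of training. A quick check of that arithmetic (plain Python, illustrative only):

```python
train_files = 370_000   # approximate CommonVoice 11 'be' train size quoted in the notes
batch_size = 64
eval_steps = 1_000

steps_per_epoch = -(-train_files // batch_size)              # ceiling division -> 5782
max_steps = -(-steps_per_epoch // eval_steps) * eval_steps   # round up to a multiple of eval_steps
print(steps_per_epoch, max_steps)                            # 5782 6000
```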
src/{run.sh → run_small.sh}
RENAMED
@@ -7,10 +7,10 @@ python src/run_speech_recognition_seq2seq_streaming.py \
     --eval_split_name="validation" \
     --model_index_name="Whisper Small Belarusian" \
     \
-    --max_steps="…" \
+    --max_steps="18000" \
     --output_dir="./" \
     --per_device_train_batch_size="64" \
-    --per_device_eval_batch_size="…" \
+    --per_device_eval_batch_size="32" \
     --logging_steps="50" \
     --logging_first_step \
     --learning_rate="1e-4" \
@@ -34,10 +34,12 @@ python src/run_speech_recognition_seq2seq_streaming.py \
     \
     --do_train \
     --do_eval \
+    --resume_from_checkpoint="checkpoint-12000" \
     --ignore_data_skip \
     --predict_with_generate \
     --do_normalize_eval \
-    --… \
+    --streaming_train="True" \
+    --streaming_eval="False" \
     --use_auth_token \
     --push_to_hub \
     --hub_model_id="ales/whisper-small-belarusian"
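The new `--resume_from_checkpoint="checkpoint-12000"` flag goes through the standard `Seq2SeqTrainer` resume path. A minimal sketch of how the HF example scripts usually wire it up (assuming the `training_args`, `last_checkpoint` and `trainer` objects defined in the training script; not copied verbatim from this repo):

```python
# Decide which checkpoint (if any) to resume from.
checkpoint = None
if training_args.resume_from_checkpoint is not None:
    checkpoint = training_args.resume_from_checkpoint   # e.g. "checkpoint-12000"
elif last_checkpoint is not None:
    checkpoint = last_checkpoint                         # auto-detected in output_dir

train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model()
```

`--ignore_data_skip` pairs with this: it tells the Trainer not to fast-forward the dataloader to the batch it stopped at, which avoids a long replay over a streamed dataset at the cost of re-seeing early samples.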
src/run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -220,9 +220,13 @@ class DataTrainingArguments:
             )
         },
     )
-    …
+    streaming_train: bool = field(
         default=True,
-        metadata={"help": "Whether to use streaming mode to load and pre-process the …"},
+        metadata={"help": "Whether to use streaming mode to load and pre-process the train split."},
+    )
+    streaming_eval: bool = field(
+        default=True,
+        metadata={"help": "Whether to use streaming mode to load and pre-process the evaluation split."},
     )


@@ -360,12 +364,14 @@ def main():
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
+
     logger.info(f"Training/evaluation parameters {training_args}")
+    logger.info(f"Data parameters: {data_args}")
+    logger.info(f"Model parameters: {model_args}")

     # Set the verbosity to info of the Transformers logger (on main process only):
     if is_main_process(training_args.local_rank):
         transformers.utils.logging.set_verbosity_info()
-        logger.info("Training/evaluation parameters %s", training_args)

     # 3. Detecting last checkpoint and eventually continue from last checkpoint
     last_checkpoint = None
@@ -423,27 +429,31 @@ def main():
     set_seed(training_args.seed)

     # 4. Load dataset
-    …
+
+    # TODO: replace dataset dicts with a single key to IterableDataset and to Dataset.
+    # don't know how to do it now - using dicts simply because they work.
+    raw_train = IterableDatasetDict() if data_args.streaming_train else DatasetDict()
+    raw_eval = IterableDatasetDict() if data_args.streaming_eval else DatasetDict()

     if training_args.do_train:
-        …
+        raw_train['train'] = load_maybe_streaming_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
             split=data_args.train_split_name,
             use_auth_token=True if model_args.use_auth_token else None,
-            streaming=data_args.…
+            streaming=data_args.streaming_train,
         )

     if training_args.do_eval:
-        …
+        raw_eval['eval'] = load_maybe_streaming_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
             split=data_args.eval_split_name,
             use_auth_token=True if model_args.use_auth_token else None,
-            streaming=data_args.…
+            streaming=data_args.streaming_eval,
         )

-    raw_datasets_features = list(next(iter(…
+    raw_datasets_features = list(next(iter(raw_train.values())).features.keys())

     if data_args.audio_column_name not in raw_datasets_features:
         raise ValueError(
@@ -510,7 +520,13 @@ def main():
     tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)

     # 6. Explicitly resample speech dataset
-    …
+    raw_train = raw_train.cast_column(
+        data_args.audio_column_name, datasets.features.Audio(
+            sampling_rate=feature_extractor.sampling_rate,
+            mono=True
+        )
+    )
+    raw_eval = raw_eval.cast_column(
         data_args.audio_column_name, datasets.features.Audio(
             sampling_rate=feature_extractor.sampling_rate,
             mono=True
@@ -531,60 +547,84 @@ def main():
     normalizer = BelarusianTextNormalizer()  # custom normalizer based on 'official' text normalizer from OpenAI

     if data_args.max_train_samples is not None:
-        …
+        raw_train['train'] = (
+            raw_train['train'].take(data_args.max_train_samples)
+            if data_args.streaming_train
+            else raw_train['train'].select(range(data_args.max_train_samples))
         )

     if data_args.max_eval_samples is not None:
-        …
+        raw_eval['eval'] = (
+            raw_eval['eval'].take(data_args.max_eval_samples)
+            if data_args.streaming_eval
+            else raw_eval['eval'].select(range(data_args.max_eval_samples))
         )

-    def prepare_dataset(…
+    def prepare_dataset(sample, labels_max_len: int = None):
         # process audio
-        …
+        audio = sample[audio_column_name]
+        inputs = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
         # process audio length
-        …
+        sample[model_input_name] = inputs.get(model_input_name)[0]
+        sample["input_length"] = len(audio["array"])

         # process targets
-        input_str = …
+        input_str = sample[text_column_name].lower() if do_lower_case else sample[text_column_name]
         if do_remove_punctuation:
             input_str = normalizer(input_str).strip()
-        …
+        sample['labels'] = tokenizer(input_str).input_ids
+        sample['labels_length'] = len(sample['labels'])  # include special characters

-        …
+        sample['labels_truncated'] = 0
         # need to truncate validation and test labels that are longer that model.config.max_length.
         # can't drop such examples because this will affect validation and test scores.
         # thus need to truncate.
         if labels_max_len is not None:
-            if len(…
+            if len(sample['labels']) > labels_max_len:
+                sample['labels'] = sample['labels'][:labels_max_len]
+                sample['labels_truncated'] = 1

-        return …
+        return sample

     with training_args.main_process_first(desc="dataset map pre-processing"):
-        …
+        logger.info(f'vectorizing dataset')
+
+        # TODO: replace dataset dicts with a single key to IterableDataset and to Dataset.
+        # don't know how to do it now - using dicts simply because they work.
+        vectorized_train = IterableDatasetDict() if data_args.streaming_train else DatasetDict()
+        vectorized_eval = IterableDatasetDict() if data_args.streaming_eval else DatasetDict()
+
+        num_proc = None
+        if data_args.streaming_train or data_args.streaming_eval:
+            logger.info(f'will preprocess data using {num_proc} processes.')
+
+        if data_args.streaming_train:
+            vectorized_train['train'] = raw_train['train'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                fn_kwargs=dict(labels_max_len=None),
+            ).with_format("torch")
+        else:
+            vectorized_train['train'] = raw_train['train'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                num_proc=num_proc,
+                fn_kwargs=dict(labels_max_len=None),
+            ).with_format("torch")
+
+        if data_args.streaming_eval:
+            vectorized_eval['eval'] = raw_eval['eval'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                fn_kwargs=dict(labels_max_len=max_labels_length),
+            ).with_format("torch")
+        else:
+            vectorized_eval['eval'] = raw_eval['eval'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                num_proc=num_proc,
+                fn_kwargs=dict(labels_max_len=max_labels_length),
+            ).with_format("torch")
+
+        if training_args.do_train and data_args.streaming_train:
             # manually shuffle if streaming (done by the trainer for non-streaming)
-            …
+            vectorized_train['train'] = vectorized_train['train'].shuffle(
                 buffer_size=data_args.shuffle_buffer_size,
                 seed=training_args.seed,
             )
@@ -601,11 +641,11 @@ def main():
     if training_args.do_train:
         # Filter items from train set only.
         # Should keep them in eval set not to affect eval metrics.
-        …
+        vectorized_train['train'] = vectorized_train['train'].filter(
             is_audio_in_length_range,
             input_columns=["input_length"],
         )
-        …
+        vectorized_train['train'] = vectorized_train['train'].filter(
             are_labels_in_length_range,
             input_columns=["labels_length"],
         )
@@ -657,18 +697,20 @@ def main():
         if isinstance(train_dataloader.dataset, IterableDatasetShard):
             pass  # set_epoch() is handled by the Trainer
         elif isinstance(train_dataloader.dataset, IterableDataset):
+            logger.info(f'ShuffleCallback. shuffling train dataset. '
+                        f'seed: {training_args.seed}. dataset epoch: {train_dataloader.dataset._epoch}')
             train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)

     # Initialize Trainer
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
-        train_dataset=…
-        eval_dataset=…
+        train_dataset=vectorized_train['train'] if training_args.do_train else None,
+        eval_dataset=vectorized_eval['eval'] if training_args.do_eval else None,
         tokenizer=processor,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-        callbacks=[ShuffleCallback()] if data_args.…
+        callbacks=[ShuffleCallback()] if data_args.streaming_train else None,
     )

     # 12. Training
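For context on the `ShuffleCallback` referenced in the last hunk: a callback of this shape (consistent with the context lines above; the repo's own class additionally logs the seed and dataset epoch, as the diff shows) bumps the dataset epoch so a streamed dataset reshuffles its buffer:

```python
from datasets import IterableDataset
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard


class ShuffleCallback(TrainerCallback):
    """Re-shuffle a streaming dataset at the beginning of each epoch (sketch)."""

    def on_epoch_begin(self, args, state, control, train_dataloader=None, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDatasetShard):
            pass  # set_epoch() is handled by the Trainer
        elif isinstance(train_dataloader.dataset, IterableDataset):
            # advancing the epoch changes the effective shuffle seed of the buffer
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
```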
src/{run_debug.sh → run_tiny_debug.sh}
RENAMED
File without changes
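Tying back to the "Ways to decrease evaluation time" notes in the readme diff: the relevant knobs all live on `Seq2SeqTrainingArguments`. A sketch of the kind of configuration those notes describe (values are illustrative, not taken from the run scripts):

```python
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_eval_batch_size=8,    # little eval speed gain beyond ~8 per the notes
    predict_with_generate=True,
    generation_max_length=40,        # short cap while training; raise back to 448 for the final eval
    evaluation_strategy="steps",
    eval_steps=2000,                 # evaluating less often also cuts wall-clock time
)
```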