ales committed on
Commit
4f145b9
1 Parent(s): 4bc6ff7

updated source code

src/readme.md CHANGED
@@ -18,11 +18,24 @@ The code in this repository is a modified version of code from
  ```
 
  ## Fine-tuning todos:
  * perform evaluation of fine-tuned model on CommonVoice test set
  * Learning rate:
    * max learning rate is not the same as LR passed as a parameter to training script. it is actually lower.
    * when resuming training, LR scheduling behaves incorrectly
  * check exact sizes of train, eval, test sets of CommonVoice 11
 
  ## Resuming training from existing checkpoint
  When resuming training from existing checkpoint:
@@ -55,6 +68,138 @@ When resuming training from existing checkpoint:
  How is it overwritten when resuming training from existing checkpoint?
  * does `ShuffleCallback` work with StreamingDataset? it reshuffles data `on_epoch_begin()`,
    but does StreamingDataset have any epochs?
 
  ### Prepended tokens
  * Why are the following lines in the Data Collator?
@@ -90,40 +235,4 @@ When resuming training from existing checkpoint:
 
  * We need to tell the model what language the audio corresponds to and what task it's performing during fine-tuning. This way, it learns what audio corresponds to what language, and the difference between transcribing audio vs translating it
 
- ## Notes:
- * using CommonVoice 11 dataset in a streaming way.<br>
- use `streaming=True` for train & validation & test.<br>
- as an alternative, we can use `streaming=False` for validation & test sets to save time on data processing.
- but the size of validation and test sets are unknown (need to check).
- it's likely they are going to be large - thus pre-download of these sets might not reduce
- overall fine-tuning time compared to streaming mode.
- * size of train set is ~370'000 audiofiles. if using `batch_size=64`, then
- 1 epoch will have ~5782 steps. <br>
- Because of `--eval_steps="1000"` will use `--max_steps="6000"` instead of `--max_steps="5800"`
- to have evaluation metrics computed in the end of training.
- * if using Google Colab, need to execute `sudo chmod -R 777 .git` inside hf repo to
- to set right permissions to be able to push trained models to HuggingFace Hub
- * Whispers BasicTextNormalizer splits words containing apostrophe:
- ```python
- > from transformers.models.whisper.english_normalizer import BasicTextNormalizer
- > normalizer = BasicTextNormalizer()
- > normalizer("раз'яднаць")
- 'раз яднаць'
- ```
- * That's why `BelarusianTextNormalizer` (edited version of `BasicTextNormalizer`) was added to training script:
- ```python
- > from run_speech_recognition_seq2seq_streaming import BelarusianTextNormalizer
- > normalizer_be = BelarusianTextNormalizer()
- > normalizer_be("раз'яднаць")
- "раз'яднаць"
- ```
- * Need to set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible
- * Default Linear scheduler is used
- * Default Adam optimizer is used
- * To save memory (and increase either model or batch_size) can experiment with:
- * using Adafactor instead of Adam.
- Adam requires two optimiser params per one model param, but Adafactor uses only one.
- > A word of caution: Adafactor is untested for fine-tuning Whisper,
- so we are unsure sure how Adafactor performance compares to Adam!
- * using Adam 8bit from `bitsandbytes` module.
- need to provide `optim="adamw_bnb_8bit"` param to `Seq2SeqTrainingArguments`
 
  ```
 
  ## Fine-tuning todos:
+ * logs are printed only right before the evaluation:<br>
+ ```
+ --logging_steps="50"
+ --eval_steps="1000"
+ ```
+ * on the next run:
+   * download the whole dataset before the launch.
+     this will probably save some time on data processing,
+     and allow loading and preparing data in parallel
+   * can also decrease eval batch size. currently it's probably causing the GPU to wait for the CPU to prepare the next batch
  * perform evaluation of fine-tuned model on CommonVoice test set
+ * add the [Whisper fine-tuning Event repo](https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event)
+   to remotes and merge updates from this original event repo
  * Learning rate:
    * max learning rate is not the same as LR passed as a parameter to training script. it is actually lower.
    * when resuming training, LR scheduling behaves incorrectly
  * check exact sizes of train, eval, test sets of CommonVoice 11
+ * fill TODOs in Notes section with answers and discussions from the Discord
 
  ## Resuming training from existing checkpoint
  When resuming training from existing checkpoint:
 
  How is it overwritten when resuming training from existing checkpoint?
  * does `ShuffleCallback` work with StreamingDataset? it reshuffles data `on_epoch_begin()`,
    but does StreamingDataset have any epochs?
+ * does streaming mode support parallel data load and processing? (see the sketch below)<br>
+   when using non-streaming mode we can use `dataset.map(..., num_proc=<num_proc>)`
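+ * a minimal sketch of the difference (the `add_len` helper is just a stand-in for real preprocessing; `num_proc` only exists for non-streaming `map()`):
+ ```python
+ from datasets import load_dataset
+
+ def add_len(example):
+     example["sentence_len"] = len(example["sentence"])
+     return example
+
+ # non-streaming: map() can fan out across worker processes
+ ds = load_dataset("mozilla-foundation/common_voice_11_0", "be",
+                   split="validation", streaming=False, use_auth_token=True)
+ ds = ds.map(add_len, num_proc=4)
+
+ # streaming: map() is applied lazily, one example at a time, in the main process;
+ # IterableDataset.map() does not accept num_proc
+ ds_stream = load_dataset("mozilla-foundation/common_voice_11_0", "be",
+                          split="train", streaming=True, use_auth_token=True)
+ ds_stream = ds_stream.map(add_len)
+ ```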
73
+
+
+ ## Notes:
+ * using CommonVoice 11 dataset in a streaming way.<br>
+   use `streaming=True` for train & validation & test.<br>
+   as an alternative, we can use `streaming=False` for validation & test sets to save time on data processing.
+   but the sizes of the validation and test sets are unknown (need to check).
+   it's likely they are going to be large - thus pre-downloading these sets might not reduce
+   overall fine-tuning time compared to streaming mode.
+ * size of train set is ~370'000 audio files. if using `batch_size=64`, then
+   1 epoch will have ~5782 steps.<br>
+   Because `--eval_steps="1000"` is used, will use `--max_steps="6000"` instead of `--max_steps="5800"`
+   to have evaluation metrics computed at the end of training (see the arithmetic sketch after this list)
+ * if using Google Colab, need to execute `sudo chmod -R 777 .git` inside the hf repo
+   to set the right permissions to be able to push trained models to the HuggingFace Hub
+ * Log tracking in Jupyter (not working) and in bash (works as expected with `tee`)
+ * Loggers in `run_speech.....py` do not control the `transformers` and `datasets` loggers.
+   can't redirect their outputs using handlers. it's better and easier to redirect output in bash
+ * Need to set `use_cache` to False since we're using gradient checkpointing, and the two are incompatible
+ * Default Linear scheduler is used
+ * Default Adam optimizer is used
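+ * a quick arithmetic sketch of the step counts above (the numbers are the estimates from this list):
+ ```python
+ import math
+
+ train_size = 370_000  # approximate number of train audio files
+ batch_size = 64
+ eval_steps = 1_000
+
+ steps_per_epoch = math.ceil(train_size / batch_size)              # ~5782
+ # round up to a multiple of eval_steps so training ends right after an evaluation
+ max_steps = math.ceil(steps_per_epoch / eval_steps) * eval_steps  # 6000
+ print(steps_per_epoch, max_steps)
+ ```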
94
+
+ ### Logs not printed when expected
+ * Train logs are printed only before the start of a validation.
+   During training they are not printed to stdout.
+   Everything worked fine in Colab.
+ * No progressbar for validation (at least when using streaming and an iterable dataset).
+   possible reason is that when using streaming, the dataset length is unknown.
+ * Evaluation metrics get printed to stdout only before the next validation call.
+   Everything worked fine in Colab.
+ * Possible reason: usage of `... | tee file.log`. But it's unlikely
+
+ ### Text normalization
+ * Whisper's `BasicTextNormalizer` splits words containing an apostrophe:
+ ```python
+ > from transformers.models.whisper.english_normalizer import BasicTextNormalizer
+ > normalizer = BasicTextNormalizer()
+ > normalizer("раз'яднаць")
+ 'раз яднаць'
+ ```
+ * That's why `BelarusianTextNormalizer` (an edited version of `BasicTextNormalizer`) was added to the training script (a sketch of the idea follows):
+ ```python
+ > from run_speech_recognition_seq2seq_streaming import BelarusianTextNormalizer
+ > normalizer_be = BelarusianTextNormalizer()
+ > normalizer_be("раз'яднаць")
+ "раз'яднаць"
+ ```
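+ * a hypothetical re-implementation of the idea (the real `BelarusianTextNormalizer` lives in the training script; this sketch only shows the key change - keep the apostrophe, which is word-internal in Belarusian):
+ ```python
+ import re
+ import unicodedata
+
+ class ApostropheKeepingNormalizer:
+     """Lowercase, drop punctuation/symbols except the apostrophe, squeeze spaces."""
+     def __call__(self, text: str) -> str:
+         text = text.lower()
+         text = "".join(
+             c if c == "'" or unicodedata.category(c)[0] not in "PS" else " "
+             for c in text
+         )
+         return re.sub(r"\s+", " ", text).strip()
+
+ print(ApostropheKeepingNormalizer()("раз'яднаць"))  # раз'яднаць
+ ```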
120
+
+ ### Different batch sizes for train and evaluation:
+ * Theoretically you can use a larger batch size for evaluation vs training!
+   * Training: we do a forward pass, storing all the activations, and then a backward pass, storing all the gradients
+   * Inference (evaluation): we only do a forward pass, and don't store any activations
+ * So the memory required for evaluation is much lower than it is for training
+   (we're only doing the forward pass and not storing any values)
+ * In my experience, altering the eval batch size has little effect on eval speed ->
+   I set it to a lower value as this tends to give a more responsive progress bar
+   when evaluating in non-streaming mode (the bar updates faster and more frequently); see the config sketch below
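+ * a sketch of the corresponding knobs (values mirror `run_small.sh`):
+ ```python
+ from transformers import Seq2SeqTrainingArguments
+
+ training_args = Seq2SeqTrainingArguments(
+     output_dir="./",
+     per_device_train_batch_size=64,  # bounded by activations + gradients + optimizer state
+     per_device_eval_batch_size=32,   # forward-only, so it could even be larger
+     evaluation_strategy="steps",
+     eval_steps=1000,
+     predict_with_generate=True,
+     fp16=True,
+ )
+ ```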
130
+
+ ### Slow inference. Long evaluation compared to training:
+ * Slower inference is an inherent limitation of the sequence-to-sequence architecture.
+   The auto-regressive decoding means that you have to do as many decoder forward passes as tokens generated.
+ * This is much slower than CTC, where you do a single encoder forward pass
+ * Note that 1 evaluation step **will take much longer** than 1 training step, even with the same batch sizes.
+ * With training, we do one forward pass of the encoder, one forward pass of the decoder,
+   one backward pass of the decoder and one backward pass of the encoder (=4 passes total):<br>
+ ```
+ audio -> encoder -> decoder -> labels
+          encoder <- decoder <- loss
+ ```
+ * During evaluation we do one forward pass of the encoder, and then auto-regressively generate tokens in the decoder.
+   Here, we do as many forward passes of the decoder as tokens generated.
+   So in total, we do one forward pass of the encoder, and N forward passes of the decoder,
+   where N is the number of tokens generated (can be up to the max length, which is 448...).
+   You can see that for 4 or more generated tokens, evaluation is going to be slower than training:<br>
+ ```
+ audio -> encoder -> decoder -> decoder -> decoder -> ... -> decoder -> end of sentence token
+ ```
+ * I've made a bit of a simplification here in saying that one forward pass
+   takes the same amount of time as one backward pass, but for the purposes of illustration
+   this demonstrates why evaluation is much slower than training (see the pass-counting sketch below)
+ * Essentially it doesn't really matter what you set your eval batch size to, as we're not aggregating any statistics
+   over the eval batch (in contrast, during training we compute a true gradient value based on a given batch).
+ * Since we just do a forward pass, we could even run eval with a batch size of 1 and get exactly the same results!
+ * Because we don't get much of an improvement with batch sizes beyond around 8, it's set somewhat arbitrarily
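+ * a toy pass-counting sketch (assumption carried over from above: a forward and a backward pass cost roughly the same):
+ ```python
+ TRAIN_PASSES = 4  # encoder fwd + decoder fwd + decoder bwd + encoder bwd
+
+ def eval_passes(n_generated_tokens: int) -> int:
+     # one encoder forward pass + one decoder forward pass per generated token
+     return 1 + n_generated_tokens
+
+ for n in (4, 40, 448):
+     print(f"{n:>3} tokens: {eval_passes(n)} passes, "
+           f"{eval_passes(n) / TRAIN_PASSES:.1f}x a training step")
+ ```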
157
+
+ ### Ways to decrease evaluation time during fine-tuning:
+ * reduce `generation_max_length` param (see the config sketch below):
+   * During training, we can limit the generation max length to a lower number to cut off the generation
+     after fewer tokens (e.g. 40). This will give worse results during training,
+     but we can still infer the evolution of WER performance over training.
+   * For the final eval step, we can bump the generation max length back up to 448.
+   * WER performance varies monotonically with generation max length
+     (WER can only stay equal or improve by increasing generation max length),
+     so we know that our final eval WER will be less than (improved) or equal to the WER during training
+ * We can evaluate at less frequent `eval_steps`: this reduces the number of times we have to perform evaluation
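+ * a config sketch of both ideas (the parameter names are real `Seq2SeqTrainingArguments` / `Seq2SeqTrainer` arguments; the values are illustrative):
+ ```python
+ from transformers import Seq2SeqTrainingArguments
+
+ training_args = Seq2SeqTrainingArguments(
+     output_dir="./",
+     predict_with_generate=True,
+     generation_max_length=40,    # cheap, pessimistic WER during training
+     evaluation_strategy="steps",
+     eval_steps=2000,             # evaluate less often
+ )
+
+ # after trainer.train() finishes, re-run the final evaluation at full length:
+ # final_metrics = trainer.evaluate(max_length=448)
+ ```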
168
+
+ ### Decrease inference time more generally
+ * PyTorch 2.0 and compiling the model could get you a decent speed-up
+   (https://pytorch.org/blog/Accelerating-Hugging-Face-and-TIMM-models/#hugging-face-models)
+ * Downcasting to fp16 (a combined sketch of both follows)
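+ * a combined sketch (assumes PyTorch >= 2.0 and a CUDA device; `openai/whisper-base` is just the checkpoint used elsewhere in this repo):
+ ```python
+ import torch
+ from transformers import WhisperForConditionalGeneration
+
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
+ model = model.half().to("cuda")  # fp16: halves weight memory and speeds up matmuls
+ model = torch.compile(model)     # PyTorch 2.0 kernel fusion; the first call pays the compile cost
+ ```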
173
+
+ ### Memory saving and training larger models:
+ To save memory (and increase either model or batch_size) we can experiment with:
+ * using Adafactor instead of Adam.
+   Adam requires two optimiser params per one model param, but Adafactor uses only one.
+   > A word of caution: Adafactor is untested for fine-tuning Whisper,
+   so we are unsure how Adafactor performance compares to Adam!
+ * using Adam 8bit from the `bitsandbytes` module.
+   need to provide `optim="adamw_bnb_8bit"` param to `Seq2SeqTrainingArguments`
+ * using `deepspeed`. scripts are available in the
+   [Whisper fine-tuning Event repo](https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event)
+ * load the model and processor in 8bit mode:
+ ```python
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large", device_map="auto", load_in_8bit=True)
+ processor = WhisperProcessor.from_pretrained("openai/whisper-large", load_in_8bit=True)
+ ```
+ inference loop:
+ ```python
+ device = model.device  # added so the snippet runs stand-alone (assumes a single GPU)
+ for data in dataset:
+     inputs = processor.feature_extractor(data["audio"]["array"], return_tensors="pt", sampling_rate=16_000).input_features.half().to(device)
+     forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
+     predicted_ids = model.generate(inputs, forced_decoder_ids=forced_decoder_ids)
+     text = processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=False)[0]
+     print(text)
+ ```
+ * 8bit will slow down inference compared to full/half-precision
+ * But the memory saving you get is immense (up to 4x vs full-precision).<br>
+   This is the recommended approach when you're limited on VRAM.<br>
+   If you care about inference speed, stick to full precision
 
  ### Prepended tokens
  * Why are the following lines in the Data Collator?
 
  * We need to tell the model what language the audio corresponds to and what task it's performing during fine-tuning. This way, it learns what audio corresponds to what language, and the difference between transcribing audio vs translating it
 
+
src/run_base.sh ADDED
@@ -0,0 +1,44 @@
+ python src/run_speech_recognition_seq2seq_streaming.py \
+ --model_name_or_path="openai/whisper-base" \
+ --dataset_name="mozilla-foundation/common_voice_11_0" \
+ --dataset_config_name="be" \
+ --language="be" \
+ --train_split_name="train" \
+ --eval_split_name="validation" \
+ --model_index_name="Whisper Base Belarusian" \
+ \
+ --max_steps="6000" \
+ --output_dir="./" \
+ --per_device_train_batch_size="64" \
+ --per_device_eval_batch_size="32" \
+ --logging_steps="50" \
+ --logging_first_step \
+ --learning_rate="1e-4" \
+ --warmup_steps="500" \
+ --evaluation_strategy="steps" \
+ --eval_steps="1000" \
+ --save_strategy="steps" \
+ --save_steps="1000" \
+ --gradient_checkpointing \
+ --fp16 \
+ \
+ --shuffle_buffer_size="500" \
+ --generation_max_length="225" \
+ --max_duration_in_seconds="30" \
+ --text_column_name="sentence" \
+ --freeze_feature_encoder="False" \
+ --report_to="tensorboard" \
+ --metric_for_best_model="wer" \
+ --greater_is_better="False" \
+ --load_best_model_at_end \
+ \
+ --do_train \
+ --do_eval \
+ --ignore_data_skip \
+ --predict_with_generate \
+ --do_normalize_eval \
+ --streaming_train="True" \
+ --streaming_eval="False" \
+ --use_auth_token \
+ --push_to_hub \
+ --hub_model_id="ales/whisper-base-belarusian"
src/{run.sh → run_small.sh} RENAMED
@@ -7,10 +7,10 @@ python src/run_speech_recognition_seq2seq_streaming.py \
  --eval_split_name="validation" \
  --model_index_name="Whisper Small Belarusian" \
  \
- --max_steps="12000" \
  --output_dir="./" \
  --per_device_train_batch_size="64" \
- --per_device_eval_batch_size="64" \
  --logging_steps="50" \
  --logging_first_step \
  --learning_rate="1e-4" \
@@ -34,10 +34,12 @@ python src/run_speech_recognition_seq2seq_streaming.py \
  \
  --do_train \
  --do_eval \
  --ignore_data_skip \
  --predict_with_generate \
  --do_normalize_eval \
- --streaming \
  --use_auth_token \
  --push_to_hub \
  --hub_model_id="ales/whisper-small-belarusian"
 
  --eval_split_name="validation" \
  --model_index_name="Whisper Small Belarusian" \
  \
+ --max_steps="18000" \
  --output_dir="./" \
  --per_device_train_batch_size="64" \
+ --per_device_eval_batch_size="32" \
  --logging_steps="50" \
  --logging_first_step \
  --learning_rate="1e-4" \
 
  \
  --do_train \
  --do_eval \
+ --resume_from_checkpoint="checkpoint-12000" \
  --ignore_data_skip \
  --predict_with_generate \
  --do_normalize_eval \
+ --streaming_train="True" \
+ --streaming_eval="False" \
  --use_auth_token \
  --push_to_hub \
  --hub_model_id="ales/whisper-small-belarusian"
src/run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -220,9 +220,13 @@ class DataTrainingArguments:
             )
         },
     )
-    streaming: bool = field(
         default=True,
-        metadata={"help": "Whether to use streaming mode to load and pre-process the data."},
     )
 
 
@@ -360,12 +364,14 @@ def main():
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set the verbosity to info of the Transformers logger (on main process only):
     if is_main_process(training_args.local_rank):
         transformers.utils.logging.set_verbosity_info()
-    logger.info("Training/evaluation parameters %s", training_args)
 
     # 3. Detecting last checkpoint and eventually continue from last checkpoint
     last_checkpoint = None
@@ -423,27 +429,31 @@ def main():
     set_seed(training_args.seed)
 
     # 4. Load dataset
-    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
 
     if training_args.do_train:
-        raw_datasets["train"] = load_maybe_streaming_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
             split=data_args.train_split_name,
             use_auth_token=True if model_args.use_auth_token else None,
-            streaming=data_args.streaming,
         )
 
     if training_args.do_eval:
-        raw_datasets["eval"] = load_maybe_streaming_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
             split=data_args.eval_split_name,
             use_auth_token=True if model_args.use_auth_token else None,
-            streaming=data_args.streaming,
         )
 
-    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
 
     if data_args.audio_column_name not in raw_datasets_features:
         raise ValueError(
@@ -510,7 +520,13 @@ def main():
     tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
 
     # 6. Explicitly resample speech dataset
-    raw_datasets = raw_datasets.cast_column(
         data_args.audio_column_name, datasets.features.Audio(
             sampling_rate=feature_extractor.sampling_rate,
             mono=True
@@ -531,60 +547,84 @@ def main():
     normalizer = BelarusianTextNormalizer()  # custom normalizer based on 'official' text normalizer from OpenAI
 
     if data_args.max_train_samples is not None:
-        raw_datasets["train"] = (
-            raw_datasets["train"].take(data_args.max_train_samples)
-            if data_args.streaming
-            else raw_datasets["train"].select(range(data_args.max_train_samples))
         )
 
     if data_args.max_eval_samples is not None:
-        raw_datasets["eval"] = (
-            raw_datasets["eval"].take(data_args.max_eval_samples)
-            if data_args.streaming
-            else raw_datasets["eval"].select(range(data_args.max_eval_samples))
         )
 
-    def prepare_dataset(batch, labels_max_len: int = None):
         # process audio
-        sample = batch[audio_column_name]
-        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
         # process audio length
-        batch[model_input_name] = inputs.get(model_input_name)[0]
-        batch["input_length"] = len(sample["array"])
 
         # process targets
-        input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
         if do_remove_punctuation:
             input_str = normalizer(input_str).strip()
-        batch['labels'] = tokenizer(input_str).input_ids
-        batch['labels_length'] = len(batch['labels'])  # include special characters
 
-        batch['labels_truncated'] = 0
         # need to truncate validation and test labels that are longer than model.config.max_length.
         # can't drop such examples because this will affect validation and test scores.
         # thus need to truncate.
         if labels_max_len is not None:
-            if len(batch['labels']) > labels_max_len:
-                batch['labels'] = batch['labels'][:labels_max_len]
-                batch['labels_truncated'] = 1
 
-        return batch
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
-        vectorized_datasets = IterableDatasetDict()
-
-        vectorized_datasets['train'] = raw_datasets['train'].map(
-            prepare_dataset, remove_columns=raw_datasets_features,
-            fn_kwargs=dict(labels_max_len=None),
-        ).with_format("torch")
-        vectorized_datasets['eval'] = raw_datasets['eval'].map(
-            prepare_dataset, remove_columns=raw_datasets_features,
-            fn_kwargs=dict(labels_max_len=max_labels_length),
-        ).with_format("torch")
-
-        if training_args.do_train and data_args.streaming:
             # manually shuffle if streaming (done by the trainer for non-streaming)
-            vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
                 buffer_size=data_args.shuffle_buffer_size,
                 seed=training_args.seed,
             )
@@ -601,11 +641,11 @@ def main():
     if training_args.do_train:
         # Filter items from train set only.
         # Should keep them in eval set not to affect eval metrics.
-        vectorized_datasets["train"] = vectorized_datasets["train"].filter(
             is_audio_in_length_range,
             input_columns=["input_length"],
         )
-        vectorized_datasets["train"] = vectorized_datasets["train"].filter(
             are_labels_in_length_range,
             input_columns=["labels_length"],
         )
@@ -657,18 +697,20 @@ def main():
         if isinstance(train_dataloader.dataset, IterableDatasetShard):
             pass  # set_epoch() is handled by the Trainer
         elif isinstance(train_dataloader.dataset, IterableDataset):
             train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
 
     # Initialize Trainer
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
-        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
         tokenizer=processor,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-        callbacks=[ShuffleCallback()] if data_args.streaming else None,
     )
 
     # 12. Training
 
             )
         },
     )
+    streaming_train: bool = field(
         default=True,
+        metadata={"help": "Whether to use streaming mode to load and pre-process the train split."},
+    )
+    streaming_eval: bool = field(
+        default=True,
+        metadata={"help": "Whether to use streaming mode to load and pre-process the evaluation split."},
     )
 
 
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
+
     logger.info(f"Training/evaluation parameters {training_args}")
+    logger.info(f"Data parameters: {data_args}")
+    logger.info(f"Model parameters: {model_args}")
 
     # Set the verbosity to info of the Transformers logger (on main process only):
     if is_main_process(training_args.local_rank):
         transformers.utils.logging.set_verbosity_info()
 
     # 3. Detecting last checkpoint and eventually continue from last checkpoint
     last_checkpoint = None
 
     set_seed(training_args.seed)
 
     # 4. Load dataset
+
+    # TODO: replace dataset dicts with single key to IterableDataset and to Dataset.
+    # don't know how to do it now - using dicts simply because they work.
+    raw_train = IterableDatasetDict() if data_args.streaming_train else DatasetDict()
+    raw_eval = IterableDatasetDict() if data_args.streaming_eval else DatasetDict()
 
     if training_args.do_train:
+        raw_train['train'] = load_maybe_streaming_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
             split=data_args.train_split_name,
             use_auth_token=True if model_args.use_auth_token else None,
+            streaming=data_args.streaming_train,
         )
 
     if training_args.do_eval:
+        raw_eval['eval'] = load_maybe_streaming_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
             split=data_args.eval_split_name,
             use_auth_token=True if model_args.use_auth_token else None,
+            streaming=data_args.streaming_eval,
         )
 
+    raw_datasets_features = list(next(iter(raw_train.values())).features.keys())
 
     if data_args.audio_column_name not in raw_datasets_features:
         raise ValueError(
 
     tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
 
     # 6. Explicitly resample speech dataset
+    raw_train = raw_train.cast_column(
+        data_args.audio_column_name, datasets.features.Audio(
+            sampling_rate=feature_extractor.sampling_rate,
+            mono=True
+        )
+    )
+    raw_eval = raw_eval.cast_column(
         data_args.audio_column_name, datasets.features.Audio(
             sampling_rate=feature_extractor.sampling_rate,
             mono=True
 
     normalizer = BelarusianTextNormalizer()  # custom normalizer based on 'official' text normalizer from OpenAI
 
     if data_args.max_train_samples is not None:
+        raw_train['train'] = (
+            raw_train['train'].take(data_args.max_train_samples)
+            if data_args.streaming_train
+            else raw_train['train'].select(range(data_args.max_train_samples))
         )
 
     if data_args.max_eval_samples is not None:
+        raw_eval['eval'] = (
+            raw_eval['eval'].take(data_args.max_eval_samples)
+            if data_args.streaming_eval
+            else raw_eval['eval'].select(range(data_args.max_eval_samples))
         )
 
+    def prepare_dataset(sample, labels_max_len: int = None):
         # process audio
+        audio = sample[audio_column_name]
+        inputs = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
         # process audio length
+        sample[model_input_name] = inputs.get(model_input_name)[0]
+        sample["input_length"] = len(audio["array"])
 
         # process targets
+        input_str = sample[text_column_name].lower() if do_lower_case else sample[text_column_name]
         if do_remove_punctuation:
             input_str = normalizer(input_str).strip()
+        sample['labels'] = tokenizer(input_str).input_ids
+        sample['labels_length'] = len(sample['labels'])  # include special characters
 
+        sample['labels_truncated'] = 0
         # need to truncate validation and test labels that are longer than model.config.max_length.
         # can't drop such examples because this will affect validation and test scores.
         # thus need to truncate.
         if labels_max_len is not None:
+            if len(sample['labels']) > labels_max_len:
+                sample['labels'] = sample['labels'][:labels_max_len]
+                sample['labels_truncated'] = 1
 
+        return sample
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
+        logger.info(f'vectorizing dataset')
+
+        # TODO: replace dataset dicts with single key to IterableDataset and to Dataset.
+        # don't know how to do it now - using dicts simply because they work.
+        vectorized_train = IterableDatasetDict() if data_args.streaming_train else DatasetDict()
+        vectorized_eval = IterableDatasetDict() if data_args.streaming_eval else DatasetDict()
+
+        num_proc = None
+        if data_args.streaming_train or data_args.streaming_eval:
+            logger.info(f'will preprocess data using {num_proc} processes.')
+
+        if data_args.streaming_train:
+            vectorized_train['train'] = raw_train['train'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                fn_kwargs=dict(labels_max_len=None),
+            ).with_format("torch")
+        else:
+            vectorized_train['train'] = raw_train['train'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                num_proc=num_proc,
+                fn_kwargs=dict(labels_max_len=None),
+            ).with_format("torch")
+
+        if data_args.streaming_eval:
+            vectorized_eval['eval'] = raw_eval['eval'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                fn_kwargs=dict(labels_max_len=max_labels_length),
+            ).with_format("torch")
+        else:
+            vectorized_eval['eval'] = raw_eval['eval'].map(
+                prepare_dataset, remove_columns=raw_datasets_features,
+                num_proc=num_proc,
+                fn_kwargs=dict(labels_max_len=max_labels_length),
+            ).with_format("torch")
+
+        if training_args.do_train and data_args.streaming_train:
             # manually shuffle if streaming (done by the trainer for non-streaming)
+            vectorized_train['train'] = vectorized_train['train'].shuffle(
                 buffer_size=data_args.shuffle_buffer_size,
                 seed=training_args.seed,
             )
 
     if training_args.do_train:
         # Filter items from train set only.
         # Should keep them in eval set not to affect eval metrics.
+        vectorized_train['train'] = vectorized_train['train'].filter(
             is_audio_in_length_range,
             input_columns=["input_length"],
         )
+        vectorized_train['train'] = vectorized_train['train'].filter(
             are_labels_in_length_range,
             input_columns=["labels_length"],
         )
 
         if isinstance(train_dataloader.dataset, IterableDatasetShard):
             pass  # set_epoch() is handled by the Trainer
         elif isinstance(train_dataloader.dataset, IterableDataset):
+            logger.info(f'ShuffleCallback. shuffling train dataset. '
+                        f'seed: {training_args.seed}. dataset epoch: {train_dataloader.dataset._epoch}')
             train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
 
     # Initialize Trainer
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
+        train_dataset=vectorized_train['train'] if training_args.do_train else None,
+        eval_dataset=vectorized_eval['eval'] if training_args.do_eval else None,
         tokenizer=processor,
         data_collator=data_collator,
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+        callbacks=[ShuffleCallback()] if data_args.streaming_train else None,
     )
 
     # 12. Training
src/{run_debug.sh → run_tiny_debug.sh} RENAMED
File without changes