Commit
·
5e05341
1
Parent(s):
4f87524
reset script
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -20,10 +20,8 @@ with 🤗 Datasets' streaming mode.
|
|
20 |
# You can also adapt this script for your own sequence to sequence speech
|
21 |
# recognition task. Pointers for this are left as comments.
|
22 |
|
23 |
-
import json
|
24 |
import logging
|
25 |
import os
|
26 |
-
import subprocess
|
27 |
import sys
|
28 |
from dataclasses import dataclass, field
|
29 |
from typing import Any, Dict, List, Optional, Union
|
@@ -49,12 +47,12 @@ from transformers import (
|
|
49 |
set_seed,
|
50 |
)
|
51 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
52 |
-
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, LANGUAGES
|
53 |
from transformers.trainer_pt_utils import IterableDatasetShard
|
54 |
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
55 |
from transformers.utils import check_min_version, send_example_telemetry
|
56 |
from transformers.utils.versions import require_version
|
57 |
|
|
|
58 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
59 |
check_min_version("4.25.0.dev0")
|
60 |
|
@@ -62,8 +60,6 @@ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/spe
|
|
62 |
|
63 |
logger = logging.getLogger(__name__)
|
64 |
|
65 |
-
SENDING_NOTIFICATION = "*** Sending notification to email ***"
|
66 |
-
RECIPIENT_ADDRESS = "[email protected]"
|
67 |
|
68 |
wandb_token = os.environ.get("WANDB_TOKEN", "None")
|
69 |
hf_token = os.environ.get("HF_TOKEN", None)
|
@@ -165,16 +161,10 @@ class DataTrainingArguments:
|
|
165 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
166 |
"""
|
167 |
|
168 |
-
|
169 |
-
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
170 |
-
)
|
171 |
-
dataset_train_config_name: Optional[str] = field(
|
172 |
-
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
173 |
-
)
|
174 |
-
dataset_eval_name: str = field(
|
175 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
176 |
)
|
177 |
-
|
178 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
179 |
)
|
180 |
text_column: Optional[str] = field(
|
@@ -243,16 +233,7 @@ class DataTrainingArguments:
|
|
243 |
default=True,
|
244 |
metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
245 |
)
|
246 |
-
|
247 |
-
default=None,
|
248 |
-
metadata={
|
249 |
-
"help": (
|
250 |
-
"Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
|
251 |
-
"only. For English speech recognition, it should be set to `None`."
|
252 |
-
)
|
253 |
-
},
|
254 |
-
)
|
255 |
-
language_eval: str = field(
|
256 |
default=None,
|
257 |
metadata={
|
258 |
"help": (
|
@@ -293,9 +274,6 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
293 |
|
294 |
processor: Any
|
295 |
decoder_start_token_id: int
|
296 |
-
task_id: int
|
297 |
-
# TODO: remove - infer language from dataset
|
298 |
-
language_id: int = -100
|
299 |
|
300 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
301 |
# split inputs and labels since they have to be of different lengths and need
|
@@ -303,7 +281,6 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
303 |
model_input_name = self.processor.model_input_names[0]
|
304 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
305 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
306 |
-
# lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
|
307 |
|
308 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
309 |
|
@@ -314,177 +291,40 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
|
314 |
|
315 |
# if bos token is appended in previous tokenization step,
|
316 |
# cut bos token here as it's append later anyways
|
317 |
-
|
318 |
-
# lang_token_ids = self.processor.tokenizer(lang_features).input_ids
|
319 |
-
# # Replace language and task if they are in the beginning, otherwise add them
|
320 |
-
# if (labels[:, 1] == self.task_id).all().cpu().item():
|
321 |
-
# labels[:, 0] = lang_token_ids
|
322 |
-
# labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
|
323 |
-
# else:
|
324 |
-
# # convert task id to tensor of labels dim to concatenate
|
325 |
-
# task_id = torch.full_like(labels[:, 0], self.task_id)
|
326 |
-
# labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
|
327 |
-
|
328 |
-
# Set language to pad token
|
329 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
330 |
-
labels
|
331 |
-
# labels[:, 0] = torch.full_like(labels[:, 0], -100)
|
332 |
-
# labels[:, 1] = torch.full_like(labels[:, 1], -100)
|
333 |
-
|
334 |
-
# remove start of sentence token from labels
|
335 |
-
# if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
336 |
-
# labels = labels[:, 1:]
|
337 |
-
|
338 |
-
# # add start of sentence token to labels + language + task
|
339 |
-
# labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
|
340 |
-
# labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
|
341 |
-
# labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
|
342 |
|
343 |
batch["labels"] = labels
|
344 |
|
345 |
return batch
|
346 |
|
347 |
|
348 |
-
def
|
349 |
-
"""
|
350 |
-
Send an email to the specified address with the specified message
|
351 |
-
"""
|
352 |
-
sender = os.environ.get("EMAIL_ADDRESS", None)
|
353 |
-
password = os.environ.get("EMAIL_PASSWORD", None)
|
354 |
-
if sender is None:
|
355 |
-
logging.warning("No email address specified, not sending notification")
|
356 |
-
if password is None:
|
357 |
-
logging.warning("No email password specified, not sending notification")
|
358 |
-
if message is None:
|
359 |
-
message = "Training is finished!"
|
360 |
-
|
361 |
-
if sender is not None:
|
362 |
-
import smtplib
|
363 |
-
from email.mime.text import MIMEText
|
364 |
-
|
365 |
-
msg = MIMEText(message)
|
366 |
-
msg["Subject"] = "Training updates..."
|
367 |
-
msg["From"] = "[email protected]"
|
368 |
-
msg["To"] = recipient
|
369 |
-
|
370 |
-
# send the email
|
371 |
-
smtp_obj = smtplib.SMTP("smtp.gmail.com", 587)
|
372 |
-
smtp_obj.starttls()
|
373 |
-
smtp_obj.login(sender, password)
|
374 |
-
smtp_obj.sendmail(sender, recipient, msg.as_string())
|
375 |
-
smtp_obj.quit()
|
376 |
-
|
377 |
-
|
378 |
-
def rename_col_and_resample(dataset, dataset_name, text_column_names, text_col_name_ref, audio_column_name, sampling_rate):
|
379 |
-
raw_datasets_features = list(dataset.features.keys())
|
380 |
-
logger.info(f"Dataset {dataset_name} - Features: {raw_datasets_features}")
|
381 |
-
|
382 |
-
if text_col_name_ref not in raw_datasets_features:
|
383 |
-
if len(text_column_names) == 1:
|
384 |
-
raise ValueError("None of the text column names provided found in dataset."
|
385 |
-
f"Text columns: {text_column_names}"
|
386 |
-
f"Dataset columns: {raw_datasets_features}")
|
387 |
-
flag = False
|
388 |
-
for text_column_name in text_column_names:
|
389 |
-
if text_column_name in raw_datasets_features:
|
390 |
-
logger.info(f"Renaming text column {text_column_name} to {text_col_name_ref}")
|
391 |
-
dataset = dataset.rename_column(text_column_name, text_col_name_ref)
|
392 |
-
flag = True
|
393 |
-
break
|
394 |
-
if flag is False:
|
395 |
-
raise ValueError("None of the text column names provided found in dataset."
|
396 |
-
f"Text columns: {text_column_names}"
|
397 |
-
f"Dataset columns: {raw_datasets_features}")
|
398 |
-
if audio_column_name is not None and sampling_rate is not None:
|
399 |
-
ds_sr = int(dataset.features[audio_column_name].sampling_rate)
|
400 |
-
if ds_sr != sampling_rate:
|
401 |
-
dataset = dataset.cast_column(
|
402 |
-
audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
|
403 |
-
)
|
404 |
-
|
405 |
-
raw_datasets_features = list(dataset.features.keys())
|
406 |
-
raw_datasets_features.remove(audio_column_name)
|
407 |
-
raw_datasets_features.remove(text_col_name_ref)
|
408 |
-
# Keep only audio and sentence
|
409 |
-
dataset = dataset.remove_columns(column_names=raw_datasets_features)
|
410 |
-
return dataset
|
411 |
-
|
412 |
-
|
413 |
-
def load_maybe_streaming_dataset(
|
414 |
-
dataset_names,
|
415 |
-
dataset_config_names,
|
416 |
-
split="train",
|
417 |
-
streaming=True,
|
418 |
-
audio_column_name=None,
|
419 |
-
sampling_rate=None,
|
420 |
-
**kwargs
|
421 |
-
):
|
422 |
"""
|
423 |
Utility function to load a dataset in streaming mode. For datasets with multiple splits,
|
424 |
each split is loaded individually and then splits combined by taking alternating examples from
|
425 |
each (interleaving).
|
426 |
"""
|
427 |
-
|
428 |
-
if "text_column_name" in kwargs:
|
429 |
-
text_column_names = kwargs.pop("text_column_name").split(",")
|
430 |
-
text_col_name_ref = text_column_names[0]
|
431 |
-
|
432 |
-
if "," in dataset_names or "+" in split:
|
433 |
# load multiple splits separated by the `+` symbol with streaming mode
|
434 |
-
dataset_splits = [
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
for split_name in split_names.split("+"):
|
439 |
-
if dataset_config_name:
|
440 |
-
dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
441 |
-
else:
|
442 |
-
dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
|
443 |
-
|
444 |
-
dataset = rename_col_and_resample(
|
445 |
-
dataset,
|
446 |
-
dataset_name,
|
447 |
-
text_column_names,
|
448 |
-
text_col_name_ref,
|
449 |
-
audio_column_name,
|
450 |
-
sampling_rate
|
451 |
-
)
|
452 |
-
|
453 |
-
dataset_splits.append(dataset)
|
454 |
-
|
455 |
# interleave multiple splits to form one dataset
|
456 |
-
interleaved_dataset = interleave_datasets(dataset_splits
|
457 |
return interleaved_dataset
|
458 |
else:
|
459 |
# load a single split *with* streaming mode
|
460 |
-
|
461 |
-
dataset = load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
|
462 |
-
dataset = rename_col_and_resample(
|
463 |
-
dataset,
|
464 |
-
dataset_names,
|
465 |
-
text_column_names,
|
466 |
-
text_col_name_ref,
|
467 |
-
audio_column_name,
|
468 |
-
sampling_rate
|
469 |
-
)
|
470 |
return dataset
|
471 |
|
472 |
|
473 |
-
def print_data_samples(dataset, tokenizer, max_samples=5):
|
474 |
-
shown_samples = 0
|
475 |
-
for batch in dataset:
|
476 |
-
print("Target: ", tokenizer.decode(batch["labels"]))
|
477 |
-
shown_samples += len(batch)
|
478 |
-
if shown_samples >= max_samples:
|
479 |
-
break
|
480 |
-
|
481 |
-
|
482 |
def main():
|
483 |
# 1. Parse input arguments
|
484 |
# See all possible arguments in src/transformers/training_args.py
|
485 |
# or by passing the --help flag to this script.
|
486 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
487 |
-
logger.info("*** Parse args ***")
|
488 |
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
489 |
|
490 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
@@ -499,7 +339,6 @@ def main():
|
|
499 |
send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
|
500 |
|
501 |
# 2. Setup logging
|
502 |
-
logger.info("*** Setup logging ***")
|
503 |
logging.basicConfig(
|
504 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
505 |
datefmt="%m/%d/%Y %H:%M:%S",
|
@@ -544,94 +383,78 @@ def main():
|
|
544 |
# Set seed before initializing model.
|
545 |
set_seed(training_args.seed)
|
546 |
|
547 |
-
# Load feature extractor
|
548 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
549 |
-
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
550 |
-
cache_dir=model_args.cache_dir,
|
551 |
-
revision=model_args.model_revision,
|
552 |
-
use_auth_token=hf_token if model_args.use_auth_token else None,
|
553 |
-
)
|
554 |
-
|
555 |
# 4. Load dataset
|
556 |
-
logger.info("*** Load dataset ***")
|
557 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
558 |
|
559 |
-
if len(data_args.language_eval.split(",")) > 1:
|
560 |
-
raise ValueError("Implementation does not support multiple language evaluation.")
|
561 |
-
|
562 |
if training_args.do_train:
|
563 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
564 |
-
data_args.
|
565 |
-
data_args.
|
566 |
split=data_args.train_split_name,
|
567 |
-
use_auth_token=
|
568 |
streaming=data_args.streaming,
|
569 |
-
text_column_name=data_args.text_column_name,
|
570 |
-
audio_column_name=data_args.audio_column_name,
|
571 |
-
sampling_rate=int(feature_extractor.sampling_rate),
|
572 |
-
# language=data_args.language_train
|
573 |
)
|
574 |
|
575 |
if training_args.do_eval:
|
576 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
577 |
-
data_args.
|
578 |
-
data_args.
|
579 |
split=data_args.eval_split_name,
|
580 |
-
use_auth_token=
|
581 |
streaming=data_args.streaming,
|
582 |
-
text_column_name=data_args.text_column_name,
|
583 |
-
audio_column_name=data_args.audio_column_name,
|
584 |
-
sampling_rate=int(feature_extractor.sampling_rate),
|
585 |
-
# language=data_args.language_eval
|
586 |
)
|
587 |
|
588 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
589 |
|
590 |
if data_args.audio_column_name not in raw_datasets_features:
|
591 |
raise ValueError(
|
592 |
-
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset. "
|
593 |
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
594 |
f"{', '.join(raw_datasets_features)}."
|
595 |
)
|
596 |
|
597 |
-
data_args.text_column_name = data_args.text_column_name.split(",")[0]
|
598 |
if data_args.text_column_name not in raw_datasets_features:
|
599 |
raise ValueError(
|
600 |
-
f"--text_column_name {data_args.text_column_name} not found in dataset. "
|
601 |
"Make sure to set `--text_column_name` to the correct text column - one of "
|
602 |
f"{', '.join(raw_datasets_features)}."
|
603 |
)
|
604 |
|
605 |
# 5. Load pretrained model, tokenizer, and feature extractor
|
606 |
-
|
607 |
# Distributed training:
|
608 |
# The .from_pretrained methods guarantee that only one local process can concurrently
|
609 |
config = AutoConfig.from_pretrained(
|
610 |
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
611 |
cache_dir=model_args.cache_dir,
|
612 |
revision=model_args.model_revision,
|
613 |
-
use_auth_token=
|
614 |
)
|
615 |
|
616 |
-
# Forced decoder ids will be overwritten before evaluation
|
617 |
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
618 |
|
619 |
if training_args.gradient_checkpointing:
|
620 |
config.update({"use_cache": False})
|
621 |
|
|
|
|
|
|
|
|
|
|
|
|
|
622 |
tokenizer = AutoTokenizer.from_pretrained(
|
623 |
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
624 |
cache_dir=model_args.cache_dir,
|
625 |
use_fast=model_args.use_fast_tokenizer,
|
626 |
revision=model_args.model_revision,
|
627 |
-
use_auth_token=
|
628 |
)
|
629 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
630 |
model_args.model_name_or_path,
|
631 |
config=config,
|
632 |
cache_dir=model_args.cache_dir,
|
633 |
revision=model_args.model_revision,
|
634 |
-
use_auth_token=
|
635 |
)
|
636 |
|
637 |
if model.config.decoder_start_token_id is None:
|
@@ -642,26 +465,20 @@ def main():
|
|
642 |
|
643 |
if model_args.freeze_encoder:
|
644 |
model.freeze_encoder()
|
645 |
-
|
646 |
-
tokenizer.set_prefix_tokens(language="swedish", task=data_args.task)
|
647 |
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
# tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
|
652 |
-
# elif data_args.language_train is not None and len(data_args.language_train.split(",")) > 1:
|
653 |
-
# # make sure language and task are not stored in the model config
|
654 |
-
# model.config.forced_decoder_ids = None
|
655 |
|
656 |
# 6. Resample speech dataset if necessary
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
|
|
|
661 |
|
662 |
# 7. Preprocessing the datasets.
|
663 |
# We need to read the audio files as arrays and tokenize the targets.
|
664 |
-
logger.info("*** Preprocess dataset ***")
|
665 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
666 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
667 |
audio_column_name = data_args.audio_column_name
|
@@ -701,7 +518,6 @@ def main():
|
|
701 |
return batch
|
702 |
|
703 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
704 |
-
# raw_datasets_features.remove("language")
|
705 |
vectorized_datasets = raw_datasets.map(
|
706 |
prepare_dataset,
|
707 |
remove_columns=raw_datasets_features,
|
@@ -726,7 +542,6 @@ def main():
|
|
726 |
)
|
727 |
|
728 |
# 8. Load Metric
|
729 |
-
logger.info("*** Load metric ***")
|
730 |
metric = evaluate.load("wer")
|
731 |
do_normalize_eval = data_args.do_normalize_eval
|
732 |
|
@@ -751,7 +566,6 @@ def main():
|
|
751 |
return {"wer": wer}
|
752 |
|
753 |
# 9. Create a single speech processor
|
754 |
-
logger.info("*** Init processor ***")
|
755 |
if is_main_process(training_args.local_rank):
|
756 |
# save feature extractor, tokenizer and config
|
757 |
feature_extractor.save_pretrained(training_args.output_dir)
|
@@ -761,20 +575,14 @@ def main():
|
|
761 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
762 |
|
763 |
# 10. Define data collator
|
764 |
-
task_token = data_args.task
|
765 |
-
if not task_token.startswith('<|'):
|
766 |
-
task_token = f'<{task_token}>'
|
767 |
-
task_id = tokenizer(task_token).input_ids[0]
|
768 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
769 |
processor=processor,
|
770 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
771 |
-
task_id=task_id
|
772 |
)
|
773 |
|
774 |
# 11. Configure Trainer
|
775 |
# Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
|
776 |
# Only required for streaming: Trainer automatically shuffles non-streaming datasets
|
777 |
-
logger.info("*** Set shuffle callback ***")
|
778 |
class ShuffleCallback(TrainerCallback):
|
779 |
def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
|
780 |
if isinstance(train_dataloader.dataset, IterableDatasetShard):
|
@@ -782,9 +590,7 @@ def main():
|
|
782 |
elif isinstance(train_dataloader.dataset, IterableDataset):
|
783 |
train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
|
784 |
|
785 |
-
|
786 |
# Initialize Trainer
|
787 |
-
logger.info("*** Init trainer ***")
|
788 |
trainer = Seq2SeqTrainer(
|
789 |
model=model,
|
790 |
args=training_args,
|
@@ -795,139 +601,63 @@ def main():
|
|
795 |
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
796 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
797 |
)
|
798 |
-
logger.info("*** Trainer initialized ***")
|
799 |
-
|
800 |
-
orig_push_to_hub = trainer.args.push_to_hub
|
801 |
-
trainer.args.push_to_hub = False
|
802 |
|
803 |
# 12. Training
|
804 |
if training_args.do_train:
|
805 |
-
logger.info("*** Train ***")
|
806 |
-
print_data_samples(vectorized_datasets["train"], tokenizer)
|
807 |
checkpoint = None
|
808 |
if training_args.resume_from_checkpoint is not None:
|
809 |
checkpoint = training_args.resume_from_checkpoint
|
810 |
elif last_checkpoint is not None:
|
811 |
checkpoint = last_checkpoint
|
812 |
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
813 |
-
logger.info("*** Training completed ***")
|
814 |
-
logger.info("*** Saving model ***")
|
815 |
-
# We don't want to push the model to the hub now
|
816 |
-
# so we temporarily set to false the push_to_hub attribute
|
817 |
-
# and then reset it to the original value
|
818 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
819 |
-
|
820 |
metrics = train_result.metrics
|
821 |
if data_args.max_train_samples:
|
822 |
metrics["train_samples"] = data_args.max_train_samples
|
823 |
-
logger.info("*** Logging metrics ***")
|
824 |
trainer.log_metrics("train", metrics)
|
825 |
-
logger.info("*** Metrics logged ***")
|
826 |
-
logger.info("*** Saving metrics ***")
|
827 |
trainer.save_metrics("train", metrics)
|
828 |
-
logger.info("*** Metrics saved ***")
|
829 |
-
logger.info("*** Saving state ***")
|
830 |
trainer.save_state()
|
831 |
-
logger.info("*** State saved ***")
|
832 |
-
|
833 |
-
# Run a test prediction to check outputs
|
834 |
-
predictions = trainer.predict(
|
835 |
-
test_dataset=vectorized_datasets["eval"].shuffle(seed=training_args.seed).take(5),
|
836 |
-
metric_key_prefix="test",
|
837 |
-
max_length=training_args.generation_max_length,
|
838 |
-
num_beams=training_args.generation_num_beams,
|
839 |
-
)
|
840 |
-
logger.info("*** Test prediction done ***")
|
841 |
-
preds = tokenizer.batch_decode(predictions.predictions)
|
842 |
-
labels = tokenizer.batch_decode(predictions.label_ids)
|
843 |
-
pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
|
844 |
-
logger.info("Before setting language and task")
|
845 |
-
logger.info(f"{pred_labels}")
|
846 |
-
language_name = LANGUAGES[data_args.language_eval]
|
847 |
-
trainer.model.config.forced_decoder_ids = \
|
848 |
-
tokenizer.get_decoder_prompt_ids(language=language_name, task=data_args.task, no_timestamps=True)
|
849 |
-
preds = tokenizer.batch_decode(predictions.predictions)
|
850 |
-
labels = tokenizer.batch_decode(predictions.label_ids)
|
851 |
-
pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
|
852 |
-
logger.info("After setting language and task")
|
853 |
-
logger.info(f"{pred_labels}")
|
854 |
|
855 |
# 13. Evaluation
|
856 |
results = {}
|
857 |
if training_args.do_eval:
|
858 |
logger.info("*** Evaluate ***")
|
859 |
-
print_data_samples(vectorized_datasets["eval"], tokenizer)
|
860 |
metrics = trainer.evaluate(
|
861 |
metric_key_prefix="eval",
|
862 |
max_length=training_args.generation_max_length,
|
863 |
num_beams=training_args.generation_num_beams,
|
864 |
)
|
865 |
-
logger.info("*** Evaluation done ***")
|
866 |
if data_args.max_eval_samples:
|
867 |
metrics["eval_samples"] = data_args.max_eval_samples
|
868 |
-
|
869 |
trainer.log_metrics("eval", metrics)
|
870 |
-
logger.info("*** Metrics logged ***")
|
871 |
-
logger.info("*** Saving metrics ***")
|
872 |
trainer.save_metrics("eval", metrics)
|
873 |
-
logger.info("*** Metrics saved ***")
|
874 |
|
875 |
# 14. Write Training Stats
|
876 |
-
logger.info("*** Writing training stats ***")
|
877 |
kwargs = {
|
878 |
"finetuned_from": model_args.model_name_or_path,
|
879 |
"tasks": "automatic-speech-recognition",
|
880 |
"tags": "whisper-event",
|
881 |
}
|
882 |
-
if data_args.
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
# kwargs["dataset"] = "\n".join(dataset_config_names_list)
|
891 |
-
# if "common_voice" in data_args.dataset_name:
|
892 |
-
# kwargs["language"] = data_args.dataset_config_name[:2]
|
893 |
-
if data_args.language_train is not None:
|
894 |
-
languages = list(set(data_args.language_train.split(",")))
|
895 |
-
kwargs["language"] = languages
|
896 |
if model_args.model_index_name is not None:
|
897 |
kwargs["model_name"] = model_args.model_index_name
|
898 |
|
899 |
-
logger.info("*** Training stats written ***")
|
900 |
-
logger.info(json.dumps(kwargs, indent=4))
|
901 |
-
|
902 |
-
# Training complete notification
|
903 |
-
logger.info("*** Training and eval complete ***")
|
904 |
-
logger.info(SENDING_NOTIFICATION)
|
905 |
-
with open(os.path.join(training_args.output_dir, "train_results.json"), "r") as f:
|
906 |
-
train_results = json.load(f)
|
907 |
-
with open(os.path.join(training_args.output_dir, "eval_results.json"), "r") as f:
|
908 |
-
eval_results = json.load(f)
|
909 |
-
notify_me(recipient=RECIPIENT_ADDRESS,
|
910 |
-
message=f"Training complete! {train_results = } {eval_results = }")
|
911 |
-
|
912 |
-
trainer.args.push_to_hub = orig_push_to_hub
|
913 |
if training_args.push_to_hub:
|
914 |
-
logger.info("*** Pushing to hub ***")
|
915 |
trainer.push_to_hub(**kwargs)
|
916 |
-
logger.info("*** Pushed to hub ***")
|
917 |
-
logger.info(SENDING_NOTIFICATION)
|
918 |
else:
|
919 |
-
logger.info("*** Creating model card ***")
|
920 |
trainer.create_model_card(**kwargs)
|
921 |
-
logger.info("*** Model card created ***")
|
922 |
-
logger.info(SENDING_NOTIFICATION)
|
923 |
-
|
924 |
-
with open(os.path.join(training_args.output_dir, "README.md"), "r") as f:
|
925 |
-
readme = f.read()
|
926 |
-
notify_me(recipient=RECIPIENT_ADDRESS,
|
927 |
-
message=f"Model pushed to hub! {readme = }")
|
928 |
|
929 |
return results
|
930 |
|
931 |
|
932 |
if __name__ == "__main__":
|
933 |
-
main()
|
|
|
20 |
# You can also adapt this script for your own sequence to sequence speech
|
21 |
# recognition task. Pointers for this are left as comments.
|
22 |
|
|
|
23 |
import logging
|
24 |
import os
|
|
|
25 |
import sys
|
26 |
from dataclasses import dataclass, field
|
27 |
from typing import Any, Dict, List, Optional, Union
|
|
|
47 |
set_seed,
|
48 |
)
|
49 |
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
|
|
50 |
from transformers.trainer_pt_utils import IterableDatasetShard
|
51 |
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
52 |
from transformers.utils import check_min_version, send_example_telemetry
|
53 |
from transformers.utils.versions import require_version
|
54 |
|
55 |
+
|
56 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
57 |
check_min_version("4.25.0.dev0")
|
58 |
|
|
|
60 |
|
61 |
logger = logging.getLogger(__name__)
|
62 |
|
|
|
|
|
63 |
|
64 |
wandb_token = os.environ.get("WANDB_TOKEN", "None")
|
65 |
hf_token = os.environ.get("HF_TOKEN", None)
|
|
|
161 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
162 |
"""
|
163 |
|
164 |
+
dataset_name: str = field(
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
166 |
)
|
167 |
+
dataset_config_name: Optional[str] = field(
|
168 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
169 |
)
|
170 |
text_column: Optional[str] = field(
|
|
|
233 |
default=True,
|
234 |
metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
|
235 |
)
|
236 |
+
language: str = field(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
default=None,
|
238 |
metadata={
|
239 |
"help": (
|
|
|
274 |
|
275 |
processor: Any
|
276 |
decoder_start_token_id: int
|
|
|
|
|
|
|
277 |
|
278 |
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
279 |
# split inputs and labels since they have to be of different lengths and need
|
|
|
281 |
model_input_name = self.processor.model_input_names[0]
|
282 |
input_features = [{model_input_name: feature[model_input_name]} for feature in features]
|
283 |
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
|
|
284 |
|
285 |
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
286 |
|
|
|
291 |
|
292 |
# if bos token is appended in previous tokenization step,
|
293 |
# cut bos token here as it's append later anyways
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
295 |
+
labels = labels[:, 1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
|
297 |
batch["labels"] = labels
|
298 |
|
299 |
return batch
|
300 |
|
301 |
|
302 |
+
def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
"""
|
304 |
Utility function to load a dataset in streaming mode. For datasets with multiple splits,
|
305 |
each split is loaded individually and then splits combined by taking alternating examples from
|
306 |
each (interleaving).
|
307 |
"""
|
308 |
+
if "+" in split:
|
|
|
|
|
|
|
|
|
|
|
309 |
# load multiple splits separated by the `+` symbol with streaming mode
|
310 |
+
dataset_splits = [
|
311 |
+
load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
|
312 |
+
for split_name in split.split("+")
|
313 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
# interleave multiple splits to form one dataset
|
315 |
+
interleaved_dataset = interleave_datasets(dataset_splits)
|
316 |
return interleaved_dataset
|
317 |
else:
|
318 |
# load a single split *with* streaming mode
|
319 |
+
dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
return dataset
|
321 |
|
322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
def main():
|
324 |
# 1. Parse input arguments
|
325 |
# See all possible arguments in src/transformers/training_args.py
|
326 |
# or by passing the --help flag to this script.
|
327 |
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
|
|
328 |
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
329 |
|
330 |
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
|
|
339 |
send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
|
340 |
|
341 |
# 2. Setup logging
|
|
|
342 |
logging.basicConfig(
|
343 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
344 |
datefmt="%m/%d/%Y %H:%M:%S",
|
|
|
383 |
# Set seed before initializing model.
|
384 |
set_seed(training_args.seed)
|
385 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
# 4. Load dataset
|
|
|
387 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
388 |
|
|
|
|
|
|
|
389 |
if training_args.do_train:
|
390 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
391 |
+
data_args.dataset_name,
|
392 |
+
data_args.dataset_config_name,
|
393 |
split=data_args.train_split_name,
|
394 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
395 |
streaming=data_args.streaming,
|
|
|
|
|
|
|
|
|
396 |
)
|
397 |
|
398 |
if training_args.do_eval:
|
399 |
raw_datasets["eval"] = load_maybe_streaming_dataset(
|
400 |
+
data_args.dataset_name,
|
401 |
+
data_args.dataset_config_name,
|
402 |
split=data_args.eval_split_name,
|
403 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
404 |
streaming=data_args.streaming,
|
|
|
|
|
|
|
|
|
405 |
)
|
406 |
|
407 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
408 |
|
409 |
if data_args.audio_column_name not in raw_datasets_features:
|
410 |
raise ValueError(
|
411 |
+
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
|
412 |
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
413 |
f"{', '.join(raw_datasets_features)}."
|
414 |
)
|
415 |
|
|
|
416 |
if data_args.text_column_name not in raw_datasets_features:
|
417 |
raise ValueError(
|
418 |
+
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
419 |
"Make sure to set `--text_column_name` to the correct text column - one of "
|
420 |
f"{', '.join(raw_datasets_features)}."
|
421 |
)
|
422 |
|
423 |
# 5. Load pretrained model, tokenizer, and feature extractor
|
424 |
+
#
|
425 |
# Distributed training:
|
426 |
# The .from_pretrained methods guarantee that only one local process can concurrently
|
427 |
config = AutoConfig.from_pretrained(
|
428 |
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
429 |
cache_dir=model_args.cache_dir,
|
430 |
revision=model_args.model_revision,
|
431 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
432 |
)
|
433 |
|
|
|
434 |
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
435 |
|
436 |
if training_args.gradient_checkpointing:
|
437 |
config.update({"use_cache": False})
|
438 |
|
439 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
440 |
+
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
441 |
+
cache_dir=model_args.cache_dir,
|
442 |
+
revision=model_args.model_revision,
|
443 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
444 |
+
)
|
445 |
tokenizer = AutoTokenizer.from_pretrained(
|
446 |
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
447 |
cache_dir=model_args.cache_dir,
|
448 |
use_fast=model_args.use_fast_tokenizer,
|
449 |
revision=model_args.model_revision,
|
450 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
451 |
)
|
452 |
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
453 |
model_args.model_name_or_path,
|
454 |
config=config,
|
455 |
cache_dir=model_args.cache_dir,
|
456 |
revision=model_args.model_revision,
|
457 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
458 |
)
|
459 |
|
460 |
if model.config.decoder_start_token_id is None:
|
|
|
465 |
|
466 |
if model_args.freeze_encoder:
|
467 |
model.freeze_encoder()
|
|
|
|
|
468 |
|
469 |
+
if data_args.language is not None:
|
470 |
+
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
471 |
+
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
|
|
|
|
|
|
|
|
472 |
|
473 |
# 6. Resample speech dataset if necessary
|
474 |
+
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
475 |
+
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
476 |
+
raw_datasets = raw_datasets.cast_column(
|
477 |
+
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
478 |
+
)
|
479 |
|
480 |
# 7. Preprocessing the datasets.
|
481 |
# We need to read the audio files as arrays and tokenize the targets.
|
|
|
482 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
483 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
484 |
audio_column_name = data_args.audio_column_name
|
|
|
518 |
return batch
|
519 |
|
520 |
with training_args.main_process_first(desc="dataset map pre-processing"):
|
|
|
521 |
vectorized_datasets = raw_datasets.map(
|
522 |
prepare_dataset,
|
523 |
remove_columns=raw_datasets_features,
|
|
|
542 |
)
|
543 |
|
544 |
# 8. Load Metric
|
|
|
545 |
metric = evaluate.load("wer")
|
546 |
do_normalize_eval = data_args.do_normalize_eval
|
547 |
|
|
|
566 |
return {"wer": wer}
|
567 |
|
568 |
# 9. Create a single speech processor
|
|
|
569 |
if is_main_process(training_args.local_rank):
|
570 |
# save feature extractor, tokenizer and config
|
571 |
feature_extractor.save_pretrained(training_args.output_dir)
|
|
|
575 |
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
576 |
|
577 |
# 10. Define data collator
|
|
|
|
|
|
|
|
|
578 |
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
579 |
processor=processor,
|
580 |
decoder_start_token_id=model.config.decoder_start_token_id,
|
|
|
581 |
)
|
582 |
|
583 |
# 11. Configure Trainer
|
584 |
# Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
|
585 |
# Only required for streaming: Trainer automatically shuffles non-streaming datasets
|
|
|
586 |
class ShuffleCallback(TrainerCallback):
|
587 |
def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
|
588 |
if isinstance(train_dataloader.dataset, IterableDatasetShard):
|
|
|
590 |
elif isinstance(train_dataloader.dataset, IterableDataset):
|
591 |
train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
|
592 |
|
|
|
593 |
# Initialize Trainer
|
|
|
594 |
trainer = Seq2SeqTrainer(
|
595 |
model=model,
|
596 |
args=training_args,
|
|
|
601 |
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
602 |
callbacks=[ShuffleCallback()] if data_args.streaming else None,
|
603 |
)
|
|
|
|
|
|
|
|
|
604 |
|
605 |
# 12. Training
|
606 |
if training_args.do_train:
|
|
|
|
|
607 |
checkpoint = None
|
608 |
if training_args.resume_from_checkpoint is not None:
|
609 |
checkpoint = training_args.resume_from_checkpoint
|
610 |
elif last_checkpoint is not None:
|
611 |
checkpoint = last_checkpoint
|
612 |
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
|
|
|
|
|
|
|
|
|
|
613 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
614 |
+
|
615 |
metrics = train_result.metrics
|
616 |
if data_args.max_train_samples:
|
617 |
metrics["train_samples"] = data_args.max_train_samples
|
|
|
618 |
trainer.log_metrics("train", metrics)
|
|
|
|
|
619 |
trainer.save_metrics("train", metrics)
|
|
|
|
|
620 |
trainer.save_state()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
621 |
|
622 |
# 13. Evaluation
|
623 |
results = {}
|
624 |
if training_args.do_eval:
|
625 |
logger.info("*** Evaluate ***")
|
|
|
626 |
metrics = trainer.evaluate(
|
627 |
metric_key_prefix="eval",
|
628 |
max_length=training_args.generation_max_length,
|
629 |
num_beams=training_args.generation_num_beams,
|
630 |
)
|
|
|
631 |
if data_args.max_eval_samples:
|
632 |
metrics["eval_samples"] = data_args.max_eval_samples
|
633 |
+
|
634 |
trainer.log_metrics("eval", metrics)
|
|
|
|
|
635 |
trainer.save_metrics("eval", metrics)
|
|
|
636 |
|
637 |
# 14. Write Training Stats
|
|
|
638 |
kwargs = {
|
639 |
"finetuned_from": model_args.model_name_or_path,
|
640 |
"tasks": "automatic-speech-recognition",
|
641 |
"tags": "whisper-event",
|
642 |
}
|
643 |
+
if data_args.dataset_name is not None:
|
644 |
+
kwargs["dataset_tags"] = data_args.dataset_name
|
645 |
+
if data_args.dataset_config_name is not None:
|
646 |
+
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
647 |
+
else:
|
648 |
+
kwargs["dataset"] = data_args.dataset_name
|
649 |
+
if "common_voice" in data_args.dataset_name:
|
650 |
+
kwargs["language"] = data_args.dataset_config_name[:2]
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
if model_args.model_index_name is not None:
|
652 |
kwargs["model_name"] = model_args.model_index_name
|
653 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
if training_args.push_to_hub:
|
|
|
655 |
trainer.push_to_hub(**kwargs)
|
|
|
|
|
656 |
else:
|
|
|
657 |
trainer.create_model_card(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
658 |
|
659 |
return results
|
660 |
|
661 |
|
662 |
if __name__ == "__main__":
|
663 |
+
main()
|