Spaces:
Running
Running
from transformers import TrainingArguments | |
from typing import Any, Optional | |
from dataclasses import dataclass, field | |
#............................................. | |
#### ARGUMENTS | |
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    The class is a dataclass so that every ``field(...)`` below becomes a real dataclass field
    (with its ``metadata["help"]`` text) instead of a plain class attribute holding a ``Field``
    object. Only ``model_name_or_path`` is required; every other field has a default.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    feature_extractor_name: Optional[str] = field(
        default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Optional: None means "fall back to the token stored by `huggingface-cli login`" (per the help text).
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias kept for backward compatibility; None means "not provided".
    use_auth_token: Optional[bool] = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    override_speaker_embeddings: bool = field(
        default=False,
        metadata={
            "help": (
                "If `True` and if `speaker_id_column_name` is specified, it will replace current speaker embeddings "
                "with a new set of speaker embeddings. "
                "If the model from the checkpoint didn't have speaker embeddings, it will initialize speaker embeddings."
            )
        },
    )
    override_vocabulary_embeddings: bool = field(
        default=False,
        metadata={
            "help": (
                "If `True`, it will resize the token embeddings based on the vocabulary size of the tokenizer. In other "
                "words, use this when you use a different tokenizer than the one that was used during pretraining."
            )
        },
    )
#............................................................................................. | |
@dataclass
class VITSTrainingArguments(TrainingArguments):
    """
    `TrainingArguments` extended with the VITS-specific options declared below: the scheduler
    mode, the loss-term weights and the discriminator optimizer hyper-parameters.

    The `@dataclass` decorator is required. Without it, the `field(...)` declarations of this
    subclass are not registered as dataclass fields. They would then be neither accepted by the
    generated `__init__` nor visible to parsers built from dataclass fields (such as
    `transformers.HfArgumentParser`), and would remain `Field` objects on the class.
    """

    do_step_schedule_per_epoch: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether or not to perform scheduler steps per epoch or per steps. If `True`, the scheduler will be "
                "`ExponentialLR` parametrized with `lr_decay`."
            )
        },
    )
    lr_decay: float = field(
        default=0.999875,
        metadata={"help": "Learning rate decay, used with `ExponentialLR` when `do_step_schedule_per_epoch`."},
    )
    # Weights applied to the individual loss terms.
    weight_duration: float = field(default=1.0, metadata={"help": "Duration loss weight."})
    weight_kl: float = field(default=1.5, metadata={"help": "KL loss weight."})
    weight_mel: float = field(default=35.0, metadata={"help": "Mel-spectrogram loss weight."})
    weight_disc: float = field(default=3.0, metadata={"help": "Discriminator loss weight."})
    weight_gen: float = field(default=1.0, metadata={"help": "Generator loss weight."})
    weight_fmaps: float = field(default=1.0, metadata={"help": "Feature map loss weight."})
    # Optimizer settings for the discriminator (the `d_` prefix), separate from the
    # generic `learning_rate` / `adam_beta1` / `adam_beta2` of `TrainingArguments`.
    d_learning_rate: float = field(default=2e-4, metadata={"help": "Discriminator learning rate."})
    d_adam_beta1: float = field(default=0.8, metadata={"help": "Discriminator Adam optimizer beta1."})
    d_adam_beta2: float = field(default=0.99, metadata={"help": "Discriminator Adam optimizer beta2."})
#............................................................................................. | |
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    The class is a dataclass so that every ``field(...)`` below becomes a real dataclass field
    (with its ``metadata["help"]`` text). All fields have defaults, so it can be built with no
    arguments. Fields that default to ``None`` are annotated ``Optional``.
    """

    project_name: str = field(
        default="vits_finetuning",
        metadata={"help": "The project name associated to this run. Useful to track your experiment."},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
    )
    speaker_id_column_name: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "If set, corresponds to the name of the speaker id column containing the speaker ids. "
                "If `override_speaker_embeddings=False`, it assumes that speakers are indexed from 0 to "
                "`num_speakers-1`, and `num_speakers` and `speaker_embedding_size` have to be set in the model config. "
                "If `override_speaker_embeddings=True`, it will use this column to compute how many speakers there are. "
                "Defaults to None, i.e. it is not used by default."
            )
        },
    )
    filter_on_speaker_id: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "If `speaker_id_column_name` and `filter_on_speaker_id` are set, will filter the dataset to keep a "
                "single speaker_id (`filter_on_speaker_id`)."
            )
        },
    )
    # NOTE(review): semantically a token count, but annotated `float`. Kept as-is so that
    # existing command lines / configs keep parsing the same way.
    max_tokens_length: float = field(
        default=450,
        metadata={
            "help": "Truncate audio files whose transcription is longer than `max_tokens_length` tokens"
        },
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={
            "help": (
                "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
                " `max_duration_in_seconds`"
            )
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to only do data preprocessing and skip training. This is especially useful when data"
                " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
                " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
                " can consequently be loaded in distributed training"
            )
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
        },
    )
    do_lower_case: bool = field(
        default=False,
        metadata={"help": "Whether the input text should be lower cased."},
    )
    do_normalize: bool = field(
        default=False,
        metadata={"help": "Whether the input waveform should be normalized."},
    )
    full_generation_sample_text: str = field(
        default="This is a test, let's see what comes out of this.",
        metadata={
            "help": (
                "Sample text used to generate a full audio sample with the model being fine-tuned, "
                "e.g. to check what the model outputs."
            )
        },
    )
    uroman_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Absolute path to the uroman package. To use if your model requires `uroman`. "
                "An easy way to check it is to go on your model card and manually check `is_uroman` in the "
                "`tokenizer_config.json`, e.g. the French checkpoint doesn't need it: "
                "https://huggingface.co/facebook/mms-tts-fra/blob/main/tokenizer_config.json#L4"
            )
        },
    )
#............................................................................................. | |