#!/usr/bin/env python3
""" |
|
Usage: |
|
|
|
export CUDA_VISIBLE_DEVICES="0,1,2,3" |
|
|
|
./conformer_ctc3/train.py \ |
|
--world-size 4 \ |
|
--num-epochs 30 \ |
|
--start-epoch 1 \ |
|
--exp-dir conformer_ctc3/exp \ |
|
--full-libri 1 \ |
|
--max-duration 300 |
|
|
|
# For mixed precision training:
|
|
|
./conformer_ctc3/train.py \ |
|
--world-size 4 \ |
|
--num-epochs 30 \ |
|
--start-epoch 1 \ |
|
--use-fp16 1 \ |
|
--exp-dir conformer_ctc3/exp \ |
|
--full-libri 1 \ |
|
--max-duration 550 |
|
|
|
# To train a streaming model:
|
./conformer_ctc3/train.py \ |
|
--world-size 4 \ |
|
--num-epochs 30 \ |
|
--start-epoch 1 \ |
|
--exp-dir conformer_ctc3/exp \ |
|
--full-libri 1 \ |
|
--dynamic-chunk-training 1 \ |
|
--causal-convolution 1 \ |
|
--short-chunk-size 25 \ |
|
--num-left-chunks 4 \ |
|
--max-duration 300 \ |
|
--delay-penalty 0.0 |
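
# For example, to resume training from the checkpoint saved at the end of
# epoch 10, set --start-epoch to 11; the script then loads
# exp-dir/epoch-10.pt before continuing (see the --start-epoch option):

./conformer_ctc3/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 11 \
  --exp-dir conformer_ctc3/exp \
  --full-libri 1 \
  --max-duration 300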
|
""" |
|
|
|
import argparse |
|
import copy |
|
import logging |
|
from pathlib import Path |
|
from shutil import copyfile |
|
from typing import Any, Dict, Optional, Tuple, Union |
|
|
|
import k2 |
|
import optim |
|
import torch |
|
import torch.multiprocessing as mp |
|
import torch.nn as nn |
|
from asr_datamodule import LibriSpeechAsrDataModule |
|
from conformer import Conformer |
|
from lhotse.cut import Cut |
|
from lhotse.dataset.sampling.base import CutSampler |
|
from lhotse.utils import fix_random_seed |
|
from model import CTCModel |
|
from optim import Eden, Eve |
|
from torch import Tensor |
|
from torch.cuda.amp import GradScaler |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
from icefall import diagnostics |
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler |
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints |
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl |
|
from icefall.checkpoint import ( |
|
save_checkpoint_with_global_batch_idx, |
|
update_averaged_model, |
|
) |
|
from icefall.dist import cleanup_dist, setup_dist |
|
from icefall.env import get_env_info |
|
from icefall.graph_compiler import CtcTrainingGraphCompiler |
|
from icefall.lexicon import Lexicon |
|
from icefall.utils import ( |
|
AttributeDict, |
|
MetricsTracker, |
|
encode_supervisions, |
|
setup_logger, |
|
str2bool, |
|
) |
|
|
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] |
|
|
|
|
|
def add_model_arguments(parser: argparse.ArgumentParser): |
|
parser.add_argument( |
|
"--dynamic-chunk-training", |
|
type=str2bool, |
|
default=False, |
|
        help="""Whether to use dynamic chunk training. This must be True
        if you want a streaming model.
        """,
|
) |
|
|
|
parser.add_argument( |
|
"--causal-convolution", |
|
type=str2bool, |
|
default=False, |
|
        help="""Whether to use causal convolution. This must be True when
        using dynamic_chunk_training.
        """,
|
) |
|
|
|
parser.add_argument( |
|
"--short-chunk-size", |
|
type=int, |
|
default=25, |
|
        help="""Chunk length for dynamic chunk training. The chunk size is
        either the maximum sequence length of the current batch or uniformly
        sampled from (1, short_chunk_size).
        """,
|
) |
|
|
|
parser.add_argument( |
|
"--num-left-chunks", |
|
type=int, |
|
default=4, |
|
        help="Number of left chunks that can be seen as context when computing attention.",
|
) |
|
|
|
|
|
def get_parser(): |
|
parser = argparse.ArgumentParser( |
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter |
|
) |
|
|
|
parser.add_argument( |
|
"--world-size", |
|
type=int, |
|
default=1, |
|
help="Number of GPUs for DDP training.", |
|
) |
|
|
|
parser.add_argument( |
|
"--master-port", |
|
type=int, |
|
default=12354, |
|
help="Master port to use for DDP training.", |
|
) |
|
|
|
parser.add_argument( |
|
"--tensorboard", |
|
type=str2bool, |
|
default=True, |
|
        help="Whether to log various information to tensorboard.",
|
) |
|
|
|
parser.add_argument( |
|
"--num-epochs", |
|
type=int, |
|
default=30, |
|
help="Number of epochs to train.", |
|
) |
|
|
|
parser.add_argument( |
|
"--start-epoch", |
|
type=int, |
|
default=1, |
|
help="""Resume training from this epoch. It should be positive. |
|
If larger than 1, it will load checkpoint from |
|
exp-dir/epoch-{start_epoch-1}.pt |
|
""", |
|
) |
|
|
|
parser.add_argument( |
|
"--start-batch", |
|
type=int, |
|
default=0, |
|
help="""If positive, --start-epoch is ignored and |
|
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt |
|
""", |
|
) |
|
|
|
parser.add_argument( |
|
"--exp-dir", |
|
type=str, |
|
default="conformer_ctc3/exp", |
|
        help="""The experiment dir.
        It specifies the directory where all training-related
        files, e.g., checkpoints, logs, etc., are saved.
        """,
|
) |
|
|
|
parser.add_argument( |
|
"--lang-dir", |
|
type=str, |
|
default="data/lang_bpe_500", |
|
        help="""The lang dir.
        It contains language-related input files such as
        "lexicon.txt".
        """,
|
) |
|
|
|
parser.add_argument( |
|
"--initial-lr", |
|
type=float, |
|
default=0.003, |
|
help="""The initial learning rate. This value should not need to be |
|
changed.""", |
|
) |
|
|
|
parser.add_argument( |
|
"--lr-batches", |
|
type=float, |
|
default=5000, |
|
        help="""Number of steps that affects how rapidly the learning rate
        decreases. We suggest not changing this.""",
|
) |
|
|
|
parser.add_argument( |
|
"--lr-epochs", |
|
type=float, |
|
default=6, |
|
help="""Number of epochs that affects how rapidly the learning rate decreases. |
|
""", |
|
) |
|
|
|
parser.add_argument( |
|
"--seed", |
|
type=int, |
|
default=42, |
|
        help="The seed for random generators, used for reproducibility.",
|
) |
|
|
|
parser.add_argument( |
|
"--print-diagnostics", |
|
type=str2bool, |
|
default=False, |
|
help="Accumulate stats on activations, print them and exit.", |
|
) |
|
|
|
parser.add_argument( |
|
"--save-every-n", |
|
type=int, |
|
default=8000, |
|
        help="""Save a checkpoint periodically, i.e., whenever
        params.batch_idx_train % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'.
        Note: It also saves a checkpoint to `exp-dir/epoch-xxx.pt` at the
        end of each epoch, where `xxx` is the epoch number counting from 1.
        """,
|
) |
|
|
|
parser.add_argument( |
|
"--keep-last-k", |
|
type=int, |
|
default=20, |
|
help="""Only keep this number of checkpoints on disk. |
|
For instance, if it is 3, there are only 3 checkpoints |
|
in the exp-dir with filenames `checkpoint-xxx.pt`. |
|
It does not affect checkpoints with name `epoch-xxx.pt`. |
|
""", |
|
) |
|
|
|
parser.add_argument( |
|
"--average-period", |
|
type=int, |
|
default=100, |
|
        help="""Update the averaged model, namely `model_avg`, after processing
        this number of batches. `model_avg` is a separate version of the model,
        in which each floating-point parameter is the average of all the
        parameters from the start of training. Each time we take the average,
        we do: `model_avg = model * (average_period / batch_idx_train) +
        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
        """,
|
) |
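
    # A small worked example of the averaging formula above (illustration
    # only, not executed): with average_period=100 and batch_idx_train=300,
    # the update is
    #     model_avg = model * (100 / 300) + model_avg * (200 / 300)
    # i.e. the current model gets weight 1/3 and the previous running average
    # weight 2/3, approximating the mean of the parameters over all batches
    # seen so far.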
|
|
|
parser.add_argument( |
|
"--use-fp16", |
|
type=str2bool, |
|
default=False, |
|
help="Whether to use half precision training.", |
|
) |
|
|
|
parser.add_argument( |
|
"--delay-penalty", |
|
type=float, |
|
default=0.0, |
|
        help="""A constant used to scale the symbol delay penalty,
        to encourage the model to emit symbols earlier for streaming models.
        It is almost the same as the `delay_penalty` in our `rnnt_loss`. See
        https://github.com/k2-fsa/k2/issues/955 and
        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
|
) |
|
|
|
parser.add_argument( |
|
"--nnet-delay-penalty", |
|
type=float, |
|
default=0.0, |
|
        help="""A constant to penalize symbol delay, which is applied to
        the nnet_output after log-softmax.
        We recommend using --delay-penalty instead.
        See https://github.com/k2-fsa/icefall/pull/669 for details.""",
|
) |
|
|
|
add_model_arguments(parser) |
|
|
|
return parser |
|
|
|
|
|
def get_params() -> AttributeDict: |
|
"""Return a dict containing training parameters. |
|
|
|
All training related parameters that are not passed from the commandline |
|
are saved in the variable `params`. |
|
|
|
Commandline options are merged into `params` after they are parsed, so |
|
you can also access them via `params`. |
|
|
|
Explanation of options saved in `params`: |
|
|
|
- best_train_loss: Best training loss so far. It is used to select |
|
the model that has the lowest training loss. It is |
|
updated during the training. |
|
|
|
- best_valid_loss: Best validation loss so far. It is used to select |
|
the model that has the lowest validation loss. It is |
|
updated during the training. |
|
|
|
- best_train_epoch: It is the epoch that has the best training loss. |
|
|
|
- best_valid_epoch: It is the epoch that has the best validation loss. |
|
|
|
    - batch_idx_train: Used for writing statistics to tensorboard. It
                       contains the number of batches trained so far across
                       epochs.
|
|
|
    - log_interval: Print training loss if batch_idx % log_interval is 0
|
|
|
- reset_interval: Reset statistics if batch_idx % reset_interval is 0 |
|
|
|
- valid_interval: Run validation if batch_idx % valid_interval is 0 |
|
|
|
- feature_dim: The model input dim. It has to match the one used |
|
in computing features. |
|
|
|
- subsampling_factor: The subsampling factor for the model. |
|
|
|
- encoder_dim: Hidden dim for multi-head attention model. |
|
|
|
    - num_encoder_layers: Number of encoder layers in the conformer encoder.

    - model_warm_step: Number of batches used for model warm-up; it determines
                       the `warmup` value passed to the model.
|
""" |
|
params = AttributeDict( |
|
{ |
|
"best_train_loss": float("inf"), |
|
"best_valid_loss": float("inf"), |
|
"best_train_epoch": -1, |
|
"best_valid_epoch": -1, |
|
"batch_idx_train": 0, |
|
"log_interval": 50, |
|
"reset_interval": 200, |
|
"valid_interval": 3000, |
|
|
|
"feature_dim": 80, |
|
"subsampling_factor": 4, |
|
"encoder_dim": 512, |
|
"nhead": 8, |
|
"dim_feedforward": 2048, |
|
"num_encoder_layers": 12, |
|
|
|
"beam_size": 10, |
|
"reduction": "none", |
|
"use_double_scores": True, |
|
|
|
"model_warm_step": 3000, |
|
"env_info": get_env_info(), |
|
} |
|
) |
|
|
|
return params |
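

# A minimal usage sketch (illustration only, not executed by this script) of
# how command-line options are merged into `params`, as described in the
# docstring of get_params():
#
#     params = get_params()
#     args = get_parser().parse_args(["--num-epochs", "20"])
#     params.update(vars(args))
#     assert params.num_epochs == 20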
|
|
|
|
|
def get_encoder_model(params: AttributeDict) -> nn.Module: |
|
|
|
encoder = Conformer( |
|
num_features=params.feature_dim, |
|
subsampling_factor=params.subsampling_factor, |
|
d_model=params.encoder_dim, |
|
nhead=params.nhead, |
|
dim_feedforward=params.dim_feedforward, |
|
num_encoder_layers=params.num_encoder_layers, |
|
dynamic_chunk_training=params.dynamic_chunk_training, |
|
short_chunk_size=params.short_chunk_size, |
|
num_left_chunks=params.num_left_chunks, |
|
causal=params.causal_convolution, |
|
) |
|
return encoder |
|
|
|
|
|
def get_ctc_model(params: AttributeDict) -> nn.Module: |
|
encoder = get_encoder_model(params) |
|
model = CTCModel( |
|
encoder=encoder, |
|
encoder_dim=params.encoder_dim, |
|
vocab_size=params.vocab_size, |
|
) |
|
return model |
|
|
|
|
|
def load_checkpoint_if_available( |
|
params: AttributeDict, |
|
model: nn.Module, |
|
model_avg: nn.Module = None, |
|
optimizer: Optional[torch.optim.Optimizer] = None, |
|
scheduler: Optional[LRSchedulerType] = None, |
|
) -> Optional[Dict[str, Any]]: |
|
"""Load checkpoint from file. |
|
|
|
If params.start_batch is positive, it will load the checkpoint from |
|
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if |
|
params.start_epoch is larger than 1, it will load the checkpoint from |
|
    `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.
|
|
|
Apart from loading state dict for `model` and `optimizer` it also updates |
|
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`, |
|
and `best_valid_loss` in `params`. |
|
|
|
Args: |
|
params: |
|
The return value of :func:`get_params`. |
|
model: |
|
The training model. |
|
model_avg: |
|
The stored model averaged from the start of training. |
|
optimizer: |
|
The optimizer that we are using. |
|
scheduler: |
|
The scheduler that we are using. |
|
Returns: |
|
Return a dict containing previously saved training info. |
|
""" |
|
if params.start_batch > 0: |
|
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" |
|
elif params.start_epoch > 1: |
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" |
|
else: |
|
return None |
|
|
|
assert filename.is_file(), f"{filename} does not exist!" |
|
|
|
saved_params = load_checkpoint( |
|
filename, |
|
model=model, |
|
model_avg=model_avg, |
|
optimizer=optimizer, |
|
scheduler=scheduler, |
|
) |
|
|
|
keys = [ |
|
"best_train_epoch", |
|
"best_valid_epoch", |
|
"batch_idx_train", |
|
"best_train_loss", |
|
"best_valid_loss", |
|
] |
|
for k in keys: |
|
params[k] = saved_params[k] |
|
|
|
if params.start_batch > 0: |
|
if "cur_epoch" in saved_params: |
|
params["start_epoch"] = saved_params["cur_epoch"] |
|
|
|
return saved_params |
|
|
|
|
|
def save_checkpoint( |
|
params: AttributeDict, |
|
model: Union[nn.Module, DDP], |
|
model_avg: Optional[nn.Module] = None, |
|
optimizer: Optional[torch.optim.Optimizer] = None, |
|
scheduler: Optional[LRSchedulerType] = None, |
|
sampler: Optional[CutSampler] = None, |
|
scaler: Optional[GradScaler] = None, |
|
rank: int = 0, |
|
) -> None: |
|
"""Save model, optimizer, scheduler and training stats to file. |
|
|
|
Args: |
|
params: |
|
It is returned by :func:`get_params`. |
|
model: |
|
The training model. |
|
model_avg: |
|
The stored model averaged from the start of training. |
|
optimizer: |
|
The optimizer used in the training. |
|
sampler: |
|
The sampler for the training dataset. |
|
scaler: |
|
        The scaler used for mixed precision training.
|
""" |
|
if rank != 0: |
|
return |
|
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" |
|
save_checkpoint_impl( |
|
filename=filename, |
|
model=model, |
|
model_avg=model_avg, |
|
params=params, |
|
optimizer=optimizer, |
|
scheduler=scheduler, |
|
sampler=sampler, |
|
scaler=scaler, |
|
rank=rank, |
|
) |
|
|
|
if params.best_train_epoch == params.cur_epoch: |
|
best_train_filename = params.exp_dir / "best-train-loss.pt" |
|
copyfile(src=filename, dst=best_train_filename) |
|
|
|
if params.best_valid_epoch == params.cur_epoch: |
|
best_valid_filename = params.exp_dir / "best-valid-loss.pt" |
|
copyfile(src=filename, dst=best_valid_filename) |
|
|
|
|
|
def compute_loss( |
|
params: AttributeDict, |
|
model: Union[nn.Module, DDP], |
|
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], |
|
batch: dict, |
|
is_training: bool, |
|
warmup: float = 1.0, |
|
) -> Tuple[Tensor, MetricsTracker]: |
|
""" |
|
    Compute CTC loss given the model and its inputs.
|
|
|
Args: |
|
params: |
|
Parameters for training. See :func:`get_params`. |
|
model: |
|
The model for training. It is an instance of Conformer in our case. |
|
graph_compiler: |
|
It is used to build a decoding graph from a ctc topo and training |
|
transcript. The training transcript is contained in the given `batch`, |
|
while the ctc topo is built when this compiler is instantiated. |
|
batch: |
|
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` |
|
for the content in it. |
|
is_training: |
|
True for training. False for validation. When it is True, this |
|
function enables autograd during computation; when it is False, it |
|
disables autograd. |
|
warmup: a floating point value which increases throughout training; |
|
values >= 1.0 are fully warmed up and have all modules present. |
|
""" |
|
device = model.device if isinstance(model, DDP) else next(model.parameters()).device |
|
feature = batch["inputs"] |
|
|
|
assert feature.ndim == 3 |
|
feature = feature.to(device) |
|
|
|
supervisions = batch["supervisions"] |
|
feature_lens = supervisions["num_frames"].to(device) |
|
|
|
with torch.set_grad_enabled(is_training): |
|
nnet_output, encoder_out_lens = model( |
|
feature, |
|
feature_lens, |
|
warmup=warmup, |
|
delay_penalty=params.nnet_delay_penalty if warmup >= 1.0 else 0, |
|
) |
|
assert torch.all(encoder_out_lens > 0) |
|
|
|
|
|
|
|
|
|
supervision_segments, texts = encode_supervisions( |
|
supervisions, subsampling_factor=params.subsampling_factor |
|
) |
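
    # Note: `supervision_segments` is a 2-D tensor with one row per
    # supervision, containing (sequence_index, start_frame, num_frames)
    # measured after subsampling and sorted by decreasing num_frames, as
    # required by k2; `texts` is the list of transcripts in the same order.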
|
|
|
if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): |
|
|
|
token_ids = graph_compiler.texts_to_ids(texts) |
|
decoding_graph = graph_compiler.compile(token_ids) |
|
elif isinstance(graph_compiler, CtcTrainingGraphCompiler): |
|
|
|
decoding_graph = graph_compiler.compile(texts) |
|
else: |
|
raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") |
|
|
|
dense_fsa_vec = k2.DenseFsaVec( |
|
nnet_output, |
|
supervision_segments, |
|
allow_truncate=params.subsampling_factor - 1, |
|
) |
|
|
|
ctc_loss = k2.ctc_loss( |
|
decoding_graph=decoding_graph, |
|
dense_fsa_vec=dense_fsa_vec, |
|
output_beam=params.beam_size, |
|
delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0, |
|
reduction=params.reduction, |
|
use_double_scores=params.use_double_scores, |
|
) |
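
    # With reduction="none" (see get_params), `ctc_loss` holds one loss value
    # per utterance; non-finite entries are filtered out below and the rest
    # are summed to form the final loss.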
|
ctc_loss_is_finite = torch.isfinite(ctc_loss) |
|
if not torch.all(ctc_loss_is_finite): |
|
logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}") |
|
ctc_loss = ctc_loss[ctc_loss_is_finite] |
|
|
|
|
|
|
|
if torch.all(~ctc_loss_is_finite): |
|
        raise ValueError(
            "All utterances in this batch lead to inf or nan losses."
        )
|
loss = ctc_loss.sum() |
|
|
|
assert loss.requires_grad == is_training |
|
|
|
info = MetricsTracker() |
|
|
|
|
|
|
|
|
|
info["frames"] = supervision_segments[:, 2].sum().item() |
|
|
|
info["utterances"] = feature.size(0) |
|
|
|
info["utt_duration"] = feature_lens.sum().item() |
|
|
|
info["utt_pad_proportion"] = ( |
|
((feature.size(1) - feature_lens) / feature.size(1)).sum().item() |
|
) |
|
|
|
|
|
info["loss"] = loss.detach().cpu().item() |
|
|
|
return loss, info |
|
|
|
|
|
def compute_validation_loss( |
|
params: AttributeDict, |
|
model: Union[nn.Module, DDP], |
|
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], |
|
valid_dl: torch.utils.data.DataLoader, |
|
world_size: int = 1, |
|
) -> MetricsTracker: |
|
"""Run the validation process.""" |
|
model.eval() |
|
|
|
tot_loss = MetricsTracker() |
|
|
|
for batch_idx, batch in enumerate(valid_dl): |
|
loss, loss_info = compute_loss( |
|
params=params, |
|
model=model, |
|
graph_compiler=graph_compiler, |
|
batch=batch, |
|
is_training=False, |
|
) |
|
assert loss.requires_grad is False |
|
tot_loss = tot_loss + loss_info |
|
|
|
if world_size > 1: |
|
tot_loss.reduce(loss.device) |
|
|
|
loss_value = tot_loss["loss"] / tot_loss["frames"] |
|
if loss_value < params.best_valid_loss: |
|
params.best_valid_epoch = params.cur_epoch |
|
params.best_valid_loss = loss_value |
|
|
|
return tot_loss |
|
|
|
|
|
def train_one_epoch( |
|
params: AttributeDict, |
|
model: Union[nn.Module, DDP], |
|
optimizer: torch.optim.Optimizer, |
|
scheduler: LRSchedulerType, |
|
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], |
|
train_dl: torch.utils.data.DataLoader, |
|
valid_dl: torch.utils.data.DataLoader, |
|
scaler: GradScaler, |
|
model_avg: Optional[nn.Module] = None, |
|
tb_writer: Optional[SummaryWriter] = None, |
|
world_size: int = 1, |
|
rank: int = 0, |
|
) -> None: |
|
"""Train the model for one epoch. |
|
|
|
The training loss from the mean of all frames is saved in |
|
`params.train_loss`. It runs the validation process every |
|
`params.valid_interval` batches. |
|
|
|
Args: |
|
params: |
|
It is returned by :func:`get_params`. |
|
model: |
|
The model for training. |
|
optimizer: |
|
The optimizer we are using. |
|
scheduler: |
|
        The learning rate scheduler; step_batch() is called for every batch.
|
graph_compiler: |
|
It is used to build a decoding graph from a ctc topo and training |
|
transcript. The training transcript is contained in the given `batch`, |
|
while the ctc topo is built when this compiler is instantiated. |
|
train_dl: |
|
Dataloader for the training dataset. |
|
valid_dl: |
|
Dataloader for the validation dataset. |
|
scaler: |
|
        The scaler used for mixed precision training.
|
model_avg: |
|
The stored model averaged from the start of training. |
|
tb_writer: |
|
Writer to write log messages to tensorboard. |
|
      world_size:
        Number of processes (GPUs) in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the process in DDP training. If no DDP is used, it should
        be set to 0.
|
""" |
|
model.train() |
|
|
|
tot_loss = MetricsTracker() |
|
|
|
for batch_idx, batch in enumerate(train_dl): |
|
params.batch_idx_train += 1 |
|
batch_size = len(batch["supervisions"]["text"]) |
|
|
|
with torch.cuda.amp.autocast(enabled=params.use_fp16): |
|
loss, loss_info = compute_loss( |
|
params=params, |
|
model=model, |
|
graph_compiler=graph_compiler, |
|
batch=batch, |
|
is_training=True, |
|
warmup=(params.batch_idx_train / params.model_warm_step), |
|
) |
|
|
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info |
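
        # Note: this keeps `tot_loss` as an exponentially decaying running sum
        # with a time constant of roughly `reset_interval` batches, so the
        # tot_loss values logged below mainly reflect recent batches.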
|
|
|
|
|
|
|
scaler.scale(loss).backward() |
|
scheduler.step_batch(params.batch_idx_train) |
|
scaler.step(optimizer) |
|
scaler.update() |
|
optimizer.zero_grad() |
|
|
|
if params.print_diagnostics and batch_idx == 30: |
|
return |
|
|
|
if ( |
|
rank == 0 |
|
and params.batch_idx_train > 0 |
|
and params.batch_idx_train % params.average_period == 0 |
|
): |
|
update_averaged_model( |
|
params=params, |
|
model_cur=model, |
|
model_avg=model_avg, |
|
) |
|
|
|
if ( |
|
params.batch_idx_train > 0 |
|
and params.batch_idx_train % params.save_every_n == 0 |
|
): |
|
save_checkpoint_with_global_batch_idx( |
|
out_dir=params.exp_dir, |
|
global_batch_idx=params.batch_idx_train, |
|
model=model, |
|
model_avg=model_avg, |
|
params=params, |
|
optimizer=optimizer, |
|
scheduler=scheduler, |
|
sampler=train_dl.sampler, |
|
scaler=scaler, |
|
rank=rank, |
|
) |
|
remove_checkpoints( |
|
out_dir=params.exp_dir, |
|
topk=params.keep_last_k, |
|
rank=rank, |
|
) |
|
|
|
if batch_idx % params.log_interval == 0: |
|
cur_lr = scheduler.get_last_lr()[0] |
|
logging.info( |
|
f"Epoch {params.cur_epoch}, " |
|
f"batch {batch_idx}, loss[{loss_info}], " |
|
f"tot_loss[{tot_loss}], batch size: {batch_size}, " |
|
f"lr: {cur_lr:.2e}" |
|
) |
|
|
|
if tb_writer is not None: |
|
tb_writer.add_scalar( |
|
"train/learning_rate", cur_lr, params.batch_idx_train |
|
) |
|
|
|
loss_info.write_summary( |
|
tb_writer, "train/current_", params.batch_idx_train |
|
) |
|
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) |
|
|
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0: |
|
logging.info("Computing validation loss") |
|
valid_info = compute_validation_loss( |
|
params=params, |
|
model=model, |
|
graph_compiler=graph_compiler, |
|
valid_dl=valid_dl, |
|
world_size=world_size, |
|
) |
|
model.train() |
|
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") |
|
if tb_writer is not None: |
|
valid_info.write_summary( |
|
tb_writer, "train/valid_", params.batch_idx_train |
|
) |
|
|
|
loss_value = tot_loss["loss"] / tot_loss["frames"] |
|
params.train_loss = loss_value |
|
if params.train_loss < params.best_train_loss: |
|
params.best_train_epoch = params.cur_epoch |
|
params.best_train_loss = params.train_loss |
|
|
|
|
|
def run(rank, world_size, args): |
|
""" |
|
Args: |
|
rank: |
|
It is a value between 0 and `world_size-1`, which is |
|
passed automatically by `mp.spawn()` in :func:`main`. |
|
        The process with rank 0 is responsible for saving checkpoints.
|
world_size: |
|
Number of GPUs for DDP training. |
|
args: |
|
The return value of get_parser().parse_args() |
|
""" |
|
params = get_params() |
|
params.update(vars(args)) |
|
if params.full_libri is False: |
|
params.valid_interval = 1600 |
|
|
|
fix_random_seed(params.seed) |
|
if world_size > 1: |
|
setup_dist(rank, world_size, params.master_port) |
|
|
|
setup_logger(f"{params.exp_dir}/log/log-train") |
|
logging.info("Training started") |
|
|
|
if args.tensorboard and rank == 0: |
|
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") |
|
else: |
|
tb_writer = None |
|
|
|
lexicon = Lexicon(params.lang_dir) |
|
max_token_id = max(lexicon.tokens) |
|
params.vocab_size = max_token_id + 1 |
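
    # Token ids in the lang dir start from 0, so the output dimension of the
    # CTC model (the vocabulary size) is max_token_id + 1.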
|
|
|
device = torch.device("cpu") |
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda", rank) |
|
logging.info(f"Device: {device}") |
|
|
|
if "lang_bpe" in str(params.lang_dir): |
|
graph_compiler = BpeCtcTrainingGraphCompiler( |
|
params.lang_dir, |
|
device=device, |
|
sos_token="<sos/eos>", |
|
eos_token="<sos/eos>", |
|
) |
|
elif "lang_phone" in str(params.lang_dir): |
|
graph_compiler = CtcTrainingGraphCompiler( |
|
lexicon, |
|
device=device, |
|
need_repeat_flag=params.delay_penalty > 0, |
|
) |
|
|
|
|
|
graph_compiler.sos_id = 1 |
|
graph_compiler.eos_id = 1 |
|
else: |
|
raise ValueError( |
|
f"Unsupported type of lang dir (we expected it to have " |
|
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" |
|
) |
|
|
|
if params.dynamic_chunk_training: |
|
assert ( |
|
params.causal_convolution |
|
), "dynamic_chunk_training requires causal convolution" |
|
|
|
logging.info(params) |
|
|
|
logging.info("About to create model") |
|
model = get_ctc_model(params) |
|
|
|
num_param = sum([p.numel() for p in model.parameters()]) |
|
logging.info(f"Number of model parameters: {num_param}") |
|
|
|
assert params.save_every_n >= params.average_period |
|
model_avg: Optional[nn.Module] = None |
|
if rank == 0: |
|
|
|
model_avg = copy.deepcopy(model) |
|
|
|
assert params.start_epoch > 0, params.start_epoch |
|
checkpoints = load_checkpoint_if_available( |
|
params=params, model=model, model_avg=model_avg |
|
) |
|
|
|
model.to(device) |
|
if world_size > 1: |
|
logging.info("Using DDP") |
|
model = DDP(model, device_ids=[rank]) |
|
|
|
optimizer = Eve(model.parameters(), lr=params.initial_lr) |
|
|
|
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) |
|
|
|
if checkpoints and "optimizer" in checkpoints: |
|
logging.info("Loading optimizer state dict") |
|
optimizer.load_state_dict(checkpoints["optimizer"]) |
|
|
|
if ( |
|
checkpoints |
|
and "scheduler" in checkpoints |
|
and checkpoints["scheduler"] is not None |
|
): |
|
logging.info("Loading scheduler state dict") |
|
scheduler.load_state_dict(checkpoints["scheduler"]) |
|
|
|
if params.print_diagnostics: |
|
diagnostic = diagnostics.attach_diagnostics(model) |
|
|
|
librispeech = LibriSpeechAsrDataModule(args) |
|
|
|
    train_cuts = librispeech.train_clean_100_cuts()
    if params.full_libri:
        # Use the full 960h training set: train-clean-100 + train-clean-360
        # + train-other-500.
        train_cuts += librispeech.train_clean_360_cuts()
        train_cuts += librispeech.train_other_500_cuts()
|
    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds;
        # very short and very long utterances are excluded from training.
        return 1.0 <= c.duration <= 20.0
|
|
|
train_cuts = train_cuts.filter(remove_short_and_long_utt) |
|
|
|
    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only restore the sampler's state dict when resuming from a
        # checkpoint saved in the middle of an epoch.
        sampler_state_dict = checkpoints["sampler"]
|
else: |
|
sampler_state_dict = None |
|
|
|
train_dl = librispeech.train_dataloaders( |
|
train_cuts, sampler_state_dict=sampler_state_dict |
|
) |
|
|
|
valid_cuts = librispeech.dev_clean_cuts() |
|
|
|
valid_dl = librispeech.valid_dataloaders(valid_cuts) |
|
|
|
if params.start_batch <= 0 and not params.print_diagnostics: |
|
scan_pessimistic_batches_for_oom( |
|
model=model, |
|
train_dl=train_dl, |
|
optimizer=optimizer, |
|
graph_compiler=graph_compiler, |
|
params=params, |
|
warmup=0.0 if params.start_epoch == 1 else 1.0, |
|
) |
|
|
|
scaler = GradScaler(enabled=params.use_fp16) |
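
    # When --use-fp16 is False the GradScaler is created with enabled=False,
    # so its scale()/step()/update() calls in train_one_epoch are effectively
    # pass-throughs apart from invoking the optimizer.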
|
if checkpoints and "grad_scaler" in checkpoints: |
|
logging.info("Loading grad scaler state dict") |
|
scaler.load_state_dict(checkpoints["grad_scaler"]) |
|
|
|
for epoch in range(params.start_epoch, params.num_epochs + 1): |
|
scheduler.step_epoch(epoch - 1) |
|
fix_random_seed(params.seed + epoch - 1) |
|
train_dl.sampler.set_epoch(epoch - 1) |
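
        # Note: seeding with (params.seed + epoch - 1) and calling
        # sampler.set_epoch() above makes the data shuffling differ across
        # epochs while remaining reproducible.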
|
|
|
if tb_writer is not None: |
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) |
|
|
|
params.cur_epoch = epoch |
|
|
|
train_one_epoch( |
|
params=params, |
|
model=model, |
|
model_avg=model_avg, |
|
optimizer=optimizer, |
|
scheduler=scheduler, |
|
graph_compiler=graph_compiler, |
|
train_dl=train_dl, |
|
valid_dl=valid_dl, |
|
scaler=scaler, |
|
tb_writer=tb_writer, |
|
world_size=world_size, |
|
rank=rank, |
|
) |
|
|
|
if params.print_diagnostics: |
|
diagnostic.print_diagnostics() |
|
break |
|
|
|
save_checkpoint( |
|
params=params, |
|
model=model, |
|
model_avg=model_avg, |
|
optimizer=optimizer, |
|
scheduler=scheduler, |
|
sampler=train_dl.sampler, |
|
scaler=scaler, |
|
rank=rank, |
|
) |
|
|
|
logging.info("Done!") |
|
|
|
if world_size > 1: |
|
torch.distributed.barrier() |
|
cleanup_dist() |
|
|
|
|
|
def scan_pessimistic_batches_for_oom( |
|
model: Union[nn.Module, DDP], |
|
train_dl: torch.utils.data.DataLoader, |
|
optimizer: torch.optim.Optimizer, |
|
graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], |
|
params: AttributeDict, |
|
warmup: float, |
|
): |
|
from lhotse.dataset import find_pessimistic_batches |
|
|
|
logging.info( |
|
"Sanity check -- see if any of the batches in epoch 1 would cause OOM." |
|
) |
|
batches, crit_values = find_pessimistic_batches(train_dl.sampler) |
|
for criterion, cuts in batches.items(): |
|
batch = train_dl.dataset[cuts] |
|
try: |
|
with torch.cuda.amp.autocast(enabled=params.use_fp16): |
|
loss, _ = compute_loss( |
|
params=params, |
|
model=model, |
|
graph_compiler=graph_compiler, |
|
batch=batch, |
|
is_training=True, |
|
warmup=warmup, |
|
) |
|
loss.backward() |
|
optimizer.step() |
|
optimizer.zero_grad() |
|
except RuntimeError as e: |
|
if "CUDA out of memory" in str(e): |
|
logging.error( |
|
"Your GPU ran out of memory with the current " |
|
"max_duration setting. We recommend decreasing " |
|
"max_duration and trying again.\n" |
|
f"Failing criterion: {criterion} " |
|
f"(={crit_values[criterion]}) ..." |
|
) |
|
raise |
|
|
|
|
|
def main(): |
|
parser = get_parser() |
|
LibriSpeechAsrDataModule.add_arguments(parser) |
|
args = parser.parse_args() |
|
args.exp_dir = Path(args.exp_dir) |
|
|
|
world_size = args.world_size |
|
assert world_size >= 1 |
|
if world_size > 1: |
|
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) |
|
else: |
|
run(rank=0, world_size=1, args=args) |
|
|
|
|
|
torch.set_num_threads(1) |
|
torch.set_num_interop_threads(1) |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|