import argparse
import logging
import os
import random

import torch
from fastai.callbacks.general_sched import GeneralScheduler, TrainingPhase
from fastai.distributed import *
from fastai.vision import *
from torch.backends import cudnn

from callbacks import DumpPrediction, IterationCallback, TextAccuracy, TopKTextAccuracy
from dataset import ImageDataset, TextDataset
from losses import MultiLosses
from utils import Config, Logger, MyDataParallel, MyConcatDataset


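# Seed Python's and PyTorch's RNGs and force deterministic cuDNN kernels so runs are reproducible.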
def _set_random_seed(seed):
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        logging.warning('You have chosen to seed training. '
                        'This will slow down your training!')


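# Build a step-decay LR schedule: one TrainingPhase per scheduler period, with the base LR
# scaled by optimizer_scheduler_gamma**i in the i-th period. `n` is expected to be the number
# of training iterations per epoch, so each phase lasts n * periods[i] iterations.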
def _get_training_phases(config, n):
    lr = np.array(config.optimizer_lr)
    periods = config.optimizer_scheduler_periods
    sigma = [config.optimizer_scheduler_gamma ** i for i in range(len(periods))]
    phases = [TrainingPhase(n * periods[i]).schedule_hp('lr', lr * sigma[i])
              for i in range(len(periods))]
    return phases


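# Instantiate one dataset per root path with the shared dataset options from the config,
# concatenating them when more than one root is given.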
def _get_dataset(ds_type, paths, is_training, config, **kwargs):
    kwargs.update({
        'img_h': config.dataset_image_height,
        'img_w': config.dataset_image_width,
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'data_aug': config.dataset_data_aug,
        'deteriorate_ratio': config.dataset_deteriorate_ratio,
        'is_training': is_training,
        'multiscales': config.dataset_multiscales,
        'one_hot_y': config.dataset_one_hot_y,
    })
    datasets = [ds_type(p, **kwargs) for p in paths]
    if len(datasets) > 1: return MyConcatDataset(datasets)
    else: return datasets[0]


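# Build a DataBunch of text-only samples (TextDataset) for the language-model pretraining stage.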
def _get_language_databaunch(config):
    kwargs = {
        'max_length': config.dataset_max_length,
        'case_sensitive': config.dataset_case_sensitive,
        'charset_path': config.dataset_charset_path,
        'smooth_label': config.dataset_smooth_label,
        'smooth_factor': config.dataset_smooth_factor,
        'one_hot_y': config.dataset_one_hot_y,
        'use_sm': config.dataset_use_sm,
    }
    train_ds = TextDataset(config.dataset_train_roots[0], is_training=True, **kwargs)
    valid_ds = TextDataset(config.dataset_test_roots[0], is_training=False, **kwargs)
    data = DataBunch.create(
        path=train_ds.path,
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory)
    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')
    return data


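# Build an ImageDataBunch over the image datasets, normalized with ImageNet statistics. In the
# test phase the training roots are replaced by the test roots so no training data is loaded.
# The added transform repacks each (image, label) item as ((image, label), label), so the label
# is also passed to the model's forward alongside the image.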
def _get_databaunch(config):
    if config.global_phase == 'test': config.dataset_train_roots = config.dataset_test_roots
    train_ds = _get_dataset(ImageDataset, config.dataset_train_roots, True, config)
    valid_ds = _get_dataset(ImageDataset, config.dataset_test_roots, False, config)
    data = ImageDataBunch.create(
        train_ds=train_ds,
        valid_ds=valid_ds,
        bs=config.dataset_train_batch_size,
        val_bs=config.dataset_test_batch_size,
        num_workers=config.dataset_num_workers,
        pin_memory=config.dataset_pin_memory).normalize(imagenet_stats)
    ar_tfm = lambda x: ((x[0], x[1]), x[1])
    data.add_tfm(ar_tfm)

    logging.info(f'{len(data.train_ds)} training items found.')
    if not data.empty_val:
        logging.info(f'{len(data.valid_ds)} valid items found.')

    return data


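# Resolve config.model_name (a dotted 'module.ClassName' path) to a class, instantiate it with
# the config, and log the resulting architecture.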
def _get_model(config):
    import importlib
    names = config.model_name.split('.')
    module_name, class_name = '.'.join(names[:-1]), names[-1]
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    return model


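# Assemble the fastai Learner: metrics, optimizer, loss, scheduling/checkpoint callbacks,
# optional distributed or DataParallel wrapping, and checkpoint restoration.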
def _get_learner(config, data, model, local_rank=None):
    strict = ifnone(config.model_strict, True)
    if config.global_stage == 'pretrain-language':
        metrics = [TopKTextAccuracy(
            k=ifnone(config.model_k, 5),
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    else:
        metrics = [TextAccuracy(
            charset_path=config.dataset_charset_path,
            max_length=config.dataset_max_length + 1,
            case_sensitive=config.dataset_eval_case_sensisitves,
            model_eval=config.model_eval)]
    opt_type = getattr(torch.optim, config.optimizer_type)
    learner = Learner(data, model, silent=True, model_dir='.',
                      true_wd=config.optimizer_true_wd,
                      wd=config.optimizer_wd,
                      bn_wd=config.optimizer_bn_wd,
                      path=config.global_workdir,
                      metrics=metrics,
                      opt_func=partial(opt_type, **config.optimizer_args or dict()),
                      loss_func=MultiLosses(one_hot=config.dataset_one_hot_y))
    learner.split(lambda m: children(m))

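    # During training, attach the LR scheduler, gradient clipping, and the iteration-driven
    # logging/eval/checkpoint callback; otherwise attach a prediction-dump callback for testing.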
    if config.global_phase == 'train':
        num_replicas = 1 if local_rank is None else torch.distributed.get_world_size()
        phases = _get_training_phases(config, len(learner.data.train_dl)//num_replicas)
        learner.callback_fns += [
            partial(GeneralScheduler, phases=phases),
            partial(GradientClipping, clip=config.optimizer_clip_grad),
            partial(IterationCallback, name=config.global_name,
                    show_iters=config.training_show_iters,
                    eval_iters=config.training_eval_iters,
                    save_iters=config.training_save_iters,
                    start_iters=config.training_start_iters,
                    stats_iters=config.training_stats_iters)]
    else:
        learner.callbacks += [
            DumpPrediction(learn=learner,
                           dataset='-'.join([Path(p).name for p in config.dataset_test_roots]),
                           charset_path=config.dataset_charset_path,
                           model_eval=config.model_eval,
                           debug=config.global_debug,
                           image_only=config.global_image_only)]

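    # Wrap the model for multi-GPU execution: SyncBatchNorm plus distributed training when a
    # local rank is given, plain DataParallel when several GPUs are visible without one.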
    learner.rank = local_rank
    if local_rank is not None:
        logging.info(f'Set model to distributed with rank {local_rank}.')
        learner.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(learner.model)
        learner.model.to(local_rank)
        learner = learner.to_distributed(local_rank)

    if torch.cuda.device_count() > 1 and local_rank is None:
        logging.info(f'Use {torch.cuda.device_count()} GPUs.')
        learner.model = MyDataParallel(learner.model)

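    # Restore weights: load an explicit checkpoint if given (copying the model directory from
    # /data/fangsc/model to /output first when the local path is missing), or fall back to the
    # best checkpoint of this run when only testing.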
    if config.model_checkpoint:
        if Path(config.model_checkpoint).exists():
            with open(config.model_checkpoint, 'rb') as f:
                buffer = io.BytesIO(f.read())
            learner.load(buffer, strict=strict)
        else:
            from distutils.dir_util import copy_tree
            src = Path('/data/fangsc/model')/config.global_name
            trg = Path('/output')/config.global_name
            if src.exists(): copy_tree(str(src), str(trg))
            learner.load(config.model_checkpoint, strict=strict)
        logging.info(f'Read model from {config.model_checkpoint}')
    elif config.global_phase == 'test':
        learner.load(f'best-{config.global_name}', strict=strict)
        logging.info(f'Read model from best-{config.global_name}')

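    # Workaround for Adadelta: a zero-epoch fit creates the optimizer so its momentum
    # hyper-parameter can be zeroed (Adadelta does not use momentum).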
    if learner.opt_func.func.__name__ == 'Adadelta':
        learner.fit(epochs=0, lr=config.optimizer_lr)
        learner.opt.mom = 0.

    return learner


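# Parse CLI arguments (overriding config values), set up logging, seeding, and optional
# distributed initialization, then build the data, model, and learner and run the chosen phase.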
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True,
                        help='path to config file')
    parser.add_argument('--phase', type=str, default=None, choices=['train', 'test'])
    parser.add_argument('--name', type=str, default=None)
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--test_root', type=str, default=None)
    parser.add_argument("--local_rank", type=int, default=None)
    parser.add_argument('--debug', action='store_true', default=None)
    parser.add_argument('--image_only', action='store_true', default=None)
    parser.add_argument('--model_strict', action='store_false', default=None)
    parser.add_argument('--model_eval', type=str, default=None,
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()
    config = Config(args.config)
    if args.name is not None: config.global_name = args.name
    if args.phase is not None: config.global_phase = args.phase
    if args.test_root is not None: config.dataset_test_roots = [args.test_root]
    if args.checkpoint is not None: config.model_checkpoint = args.checkpoint
    if args.debug is not None: config.global_debug = args.debug
    if args.image_only is not None: config.global_image_only = args.image_only
    if args.model_eval is not None: config.model_eval = args.model_eval
    if args.model_strict is not None: config.model_strict = args.model_strict

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    _set_random_seed(config.global_seed)
    logging.info(config)

    if args.local_rank is not None:
        logging.info(f'Init distributed training at device {args.local_rank}.')
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logging.info('Construct dataset.')
    if config.global_stage == 'pretrain-language': data = _get_language_databaunch(config)
    else: data = _get_databaunch(config)

    logging.info('Construct model.')
    model = _get_model(config)

    logging.info('Construct learner.')
    learner = _get_learner(config, data, model, args.local_rank)

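    # Run the requested phase: full training, or a single validation pass whose metrics
    # (loss, character/word-level accuracy, and edit-distance statistics) are logged.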
    if config.global_phase == 'train':
        logging.info('Start training.')
        learner.fit(epochs=config.training_epochs,
                    lr=config.optimizer_lr)
    else:
        logging.info('Start validation.')
        last_metrics = learner.validate()
        log_str = f'eval loss = {last_metrics[0]:6.3f}, ' \
                  f'ccr = {last_metrics[1]:6.3f}, cwr = {last_metrics[2]:6.3f}, ' \
                  f'ted = {last_metrics[3]:6.3f}, ned = {last_metrics[4]:6.0f}, ' \
                  f'ted/w = {last_metrics[5]:6.3f}, '
        logging.info(log_str)


if __name__ == '__main__':
    main()