import timeit
import numpy as np
import os
import os.path as osp
import shutil
import copy
import torch
import torch.nn as nn
import torch.distributed as dist

from .cfg_holder import cfg_unique_holder as cfguh
from . import sync

print_console_local_rank0_only = True

def print_log(*console_info):
    """Print to console (local rank 0 only when enabled) and, on local
    rank 0, also append the message to the configured log file."""
    local_rank = sync.get_rank('local')
    if print_console_local_rank0_only and (local_rank != 0):
        return
    console_info = ' '.join([str(i) for i in console_info])
    print(console_info)

    if local_rank != 0:
        return
    # Resolve the log file from the train config first, then eval;
    # silently skip file logging if neither is configured.
    log_file = None
    try:
        log_file = cfguh().cfg.train.log_file
    except Exception:
        try:
            log_file = cfguh().cfg.eval.log_file
        except Exception:
            return
    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(console_info + '\n')

class distributed_log_manager(object):
    """Accumulates per-item running sums across iterations and reduces
    them across DDP ranks; optionally mirrors values to TensorBoard."""

    def __init__(self):
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

        cfgt = cfguh().cfg.train
        use_tensorboard = getattr(cfgt, 'log_tensorboard', False)

        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

        # Only local rank 0 owns a TensorBoard writer.
        self.tb = None
        if use_tensorboard and (self.rank == 0):
            import tensorboardX
            monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        """Add n samples' worth of each value; values are stored as
        sample-weighted sums so means stay correct across batches."""
        if n < 0:
            raise ValueError('n must be a non-negative sample count')
        for itemn, di in data.items():
            if itemn in self.sum:
                self.sum[itemn] += di * n
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = di * n
                self.cnt[itemn] = n

    def get_mean_value_dict(self):
        # Per-item local means, in sorted key order so every rank packs
        # the tensor identically before the all_reduce.
        value_gather = [
            self.sum[itemn] / self.cnt[itemn]
            for itemn in sorted(self.sum.keys())]
        value_gather_tensor = torch.FloatTensor(value_gather).to(self.rank)
        if self.ddp:
            # Average the local means across ranks.
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            value_gather_tensor /= self.world_size
        mean = {}
        for idx, itemn in enumerate(sorted(self.sum.keys())):
            mean[itemn] = value_gather_tensor[idx].item()
        return mean

    def tensorboard_log(self, step, data, mode='train', **extra):
        if self.tb is None:
            return
        if mode == 'train':
            self.tb.add_scalar('other/epochn', extra['epochn'], step)
            if 'lr' in extra:
                self.tb.add_scalar('other/lr', extra['lr'], step)
            for itemn, di in data.items():
                # Route 'loss*' items to the loss/ group, the total
                # 'Loss' to its own scalar, everything else to other/.
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Build a one-line console summary and push the same means to
        TensorBoard. Does not clear the accumulators; call clear()."""
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen)]
        if lr is not None:
            console_info += ['LR:{:.4E}'.format(lr)]

        mean = self.get_mean_value_dict()

        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        loss = mean.pop('Loss')
        mean_info = ['Loss:{:.4f}'.format(loss)] + [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean.keys())
            if itemn.startswith('loss')]
        console_info += mean_info
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)

    def clear(self):
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        if self.tb is not None:
            self.tb.close()

# ----- also include some small utils -----

def torch_to_numpy(*argv):
    """Recursively convert tensors (including those nested in lists,
    tuples, and dicts) to numpy arrays; other values pass through."""
    if len(argv) > 1:
        data = list(argv)
    else:
        data = argv[0]
    if isinstance(data, torch.Tensor):
        return data.to('cpu').detach().numpy()
    elif isinstance(data, (list, tuple)):
        out = []
        for di in data:
            out.append(torch_to_numpy(di))
        return out
    elif isinstance(data, dict):
        out = {}
        for ni, di in data.items():
            out[ni] = torch_to_numpy(di)
        return out
    else:
        return data
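
# ----- usage sketch (illustrative only, not part of the library) -----
# A minimal sketch of how these helpers fit together in a training loop.
# Everything below is an assumption for illustration: `loader`,
# `model_step`, `epochn`, `samplen`, and `lr` are hypothetical names,
# and the config is expected to provide train.log_dir, train.log_file,
# and train.log_tensorboard as read by the code above.
#
#   logm = distributed_log_manager()
#   for itern, batch in enumerate(loader):
#       loss = model_step(batch)  # hypothetical forward/backward helper
#       logm.accumulate(len(batch), Loss=loss.item())
#       if itern % 100 == 0:
#           print_log(logm.train_summary(itern, epochn, samplen, lr))
#           logm.clear()          # reset running sums between summaries
#   logm.tensorboard_close()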