# NOTE: the text below is page-status residue captured with this file, not code:
#   Spaces: Build error / Build error
import logging
import os
import os.path as osp

import yaml

from util import OrderedYaml
# Loader/Dumper pair returned by the project's OrderedYaml helper; Loader is
# the loader class that parse() hands to yaml.load.
Loader, Dumper = OrderedYaml()
def _annotate_datasets(opt, scale):
    '''Fill in the derived per-dataset fields: phase, scale and data_type.

    Args:
        opt: Full options mapping; every entry of ``opt['datasets']`` is
            updated in place.
        scale: Value stored as ``dataset['scale']`` when
            ``opt['distortion'] == 'sr'``; ignored otherwise.
    '''
    is_sr = opt['distortion'] == 'sr'
    for name, dataset in opt['datasets'].items():
        # The phase is the part of the dataset key before the first '_'.
        dataset['phase'] = name.split('_')[0]
        if is_sr:
            dataset['scale'] = scale

        is_lmdb = False
        for root_key in ('dataroot_GT', 'dataroot_LQ'):
            root = dataset.get(root_key)
            if root is None:
                continue
            root = osp.expanduser(root)
            dataset[root_key] = root
            if root.endswith('lmdb'):
                is_lmdb = True
        dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
        # if dataset['mode'].endswith('mc'):  # for memcached
        #     dataset['data_type'] = 'mc'
        #     dataset['mode'] = dataset['mode'].replace('_mc', '')


def _resolve_paths(opt, is_train):
    '''Expand user paths and derive the output directories in ``opt['path']``.

    Args:
        opt: Full options mapping; ``opt['path']`` is updated in place and
            ``opt['name']`` is used as the experiment/result directory name.
        is_train: Selects the training layout (``experiments/<name>/...``)
            or the test layout (``results/<name>``).
    '''
    paths = opt['path']
    for key, path in paths.items():
        # Only non-empty strings are paths; other value types (bools, numbers)
        # are left untouched instead of crashing inside expanduser.
        if path and isinstance(path, str) and key != 'strict_load':
            paths[key] = osp.expanduser(path)

    # `<this file>/../../..` normalises to the directory two levels above the
    # directory that contains this file.
    paths['root'] = osp.abspath(
        osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))

    if is_train:
        experiments_root = osp.join(paths['root'], 'experiments', opt['name'])
        paths['experiments_root'] = experiments_root
        paths['models'] = osp.join(experiments_root, 'models')
        paths['training_state'] = osp.join(experiments_root, 'training_state')
        paths['log'] = experiments_root
        paths['val_images'] = osp.join(experiments_root, 'val_images')
    else:  # test
        # Keep a results_root given in the option file; otherwise derive it.
        if not paths.get('results_root', None):
            paths['results_root'] = osp.join(
                paths['root'], 'results', opt['name'])
        paths['log'] = paths['results_root']


def _resolve_relative_schedules(train_opt):
    '''Convert ``*_rel`` schedule fractions into absolute iteration counts.

    For each of ``T_period``, ``restarts``, ``lr_steps`` and
    ``lr_steps_inverse``, if ``<key>_rel`` is present, ``<key>`` is set to
    ``int(x * niter)`` for every ``x`` in the relative list.

    Args:
        train_opt: The ``opt['train']`` mapping, updated in place. Must
            contain ``niter``.
    '''
    niter = train_opt['niter']
    for key in ('T_period', 'restarts', 'lr_steps', 'lr_steps_inverse'):
        rel_key = key + '_rel'
        if rel_key in train_opt:
            train_opt[key] = [int(x * niter) for x in train_opt[rel_key]]


def parse(opt_path, is_train=True):
    '''Load a YAML option file and fill in the derived settings.

    Args:
        opt_path: Path of the YAML option file.
        is_train: ``True`` to set up training output paths (and the debug
            overrides when ``'debug'`` occurs in ``opt['name']``), ``False``
            to set up test/result paths.

    Returns:
        The loaded options mapping, updated in place with ``is_train``, the
        per-dataset ``phase``/``scale``/``data_type`` fields, expanded and
        derived entries in ``opt['path']``, ``network_G['scale']`` for the
        ``'sr'`` distortion, and absolute schedule lists in ``opt['train']``.

    Raises:
        KeyError: If a required key such as ``distortion``, ``datasets``,
            ``path`` or ``name`` is missing from the option file.
    '''
    with open(opt_path, mode='r') as f:
        # NOTE(review): Loader comes from util.OrderedYaml. If it derives from
        # an unsafe PyYAML loader, option files from untrusted sources could
        # construct arbitrary objects -- confirm before loading such files.
        opt = yaml.load(f, Loader=Loader)

    # CUDA_VISIBLE_DEVICES is not exported here; opt['gpu_ids'] is left as
    # loaded for the caller to use.
    opt['is_train'] = is_train

    is_sr = opt['distortion'] == 'sr'
    scale = opt['scale'] if is_sr else None

    # datasets
    _annotate_datasets(opt, scale)

    # path
    _resolve_paths(opt, is_train)

    # change some options for debug mode
    if is_train and 'debug' in opt['name']:
        opt['train']['val_freq'] = 8
        opt['logger']['print_freq'] = 1
        opt['logger']['save_checkpoint_freq'] = 8

    # network
    if is_sr:
        opt['network_G']['scale'] = scale

    # relative learning rate
    if 'train' in opt:
        _resolve_relative_schedules(opt['train'])
        print(opt['train'])

    return opt
def dict2str(opt, indent_l=1):
    '''Render a (possibly nested) dict as an indented string for the logger.

    Each entry is written on its own line as ``key: value``, indented by
    ``indent_l * 2`` spaces. A nested dict is written as ``key:[``, followed
    by its entries one level deeper, followed by a closing ``]`` line.

    Args:
        opt: Mapping to render.
        indent_l: Indentation level of the entries of ``opt``.

    Returns:
        The rendered text; an empty string for an empty mapping.
    '''
    pad = ' ' * (indent_l * 2)
    parts = []
    for k, v in opt.items():
        # str(k) so that non-string keys (e.g. integers) do not raise.
        if isinstance(v, dict):
            parts.append(pad + str(k) + ':[\n')
            parts.append(dict2str(v, indent_l + 1))
            parts.append(pad + ']\n')
        else:
            parts.append(pad + str(k) + ': ' + str(v) + '\n')
    return ''.join(parts)
class NoneDict(dict):
    '''dict whose ``d[key]`` lookup yields ``None`` for a missing key.

    Only subscript access is affected; ``key in d`` and ``d.get(key, default)``
    behave as for a normal dict, and a missing key is not inserted.
    '''

    def __missing__(self, key):
        return None


# convert to NoneDict, which returns None for a missing key.
def dict_to_nonedict(opt):
    '''Recursively convert every dict inside ``opt`` into a NoneDict.

    Args:
        opt: A dict, a list, or any other value. Dicts become NoneDict,
            lists are rebuilt with converted elements, and anything else is
            returned unchanged.

    Returns:
        The converted structure; the input itself is not modified.
    '''
    if isinstance(opt, dict):
        # Pass the mapping positionally: NoneDict(**mapping) would raise
        # TypeError for any key that is not a string.
        return NoneDict(
            (key, dict_to_nonedict(sub_opt)) for key, sub_opt in opt.items())
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt
def check_resume(opt, resume_iter):
    '''Point the pretrain_model paths at the checkpoints of a resumed run.

    Does nothing unless ``opt['path']['resume_state']`` is set. Otherwise
    ``pretrain_model_G`` is set to ``<models>/<resume_iter>_G.pth`` and, when
    ``'gan'`` occurs in ``opt['model']``, ``pretrain_model_D`` is set to
    ``<models>/<resume_iter>_D.pth``. A warning is logged if either path was
    already configured, since it is overridden.

    Args:
        opt: Options mapping; ``opt['path']`` is updated in place.
        resume_iter: Iteration number used as the checkpoint file prefix.
    '''
    logger = logging.getLogger('base')
    paths = opt['path']
    # .get so that a plain dict without the key means "not resuming".
    if not paths.get('resume_state'):
        return

    if (paths.get('pretrain_model_G', None) is not None
            or paths.get('pretrain_model_D', None) is not None):
        logger.warning(
            'pretrain_model path will be ignored when resuming training.')

    paths['pretrain_model_G'] = osp.join(
        paths['models'], '{}_G.pth'.format(resume_iter))
    logger.info('Set [pretrain_model_G] to %s', paths['pretrain_model_G'])

    if 'gan' in opt['model']:
        paths['pretrain_model_D'] = osp.join(
            paths['models'], '{}_D.pth'.format(resume_iter))
        logger.info('Set [pretrain_model_D] to %s', paths['pretrain_model_D'])