import argparse |
import random |
import torch |
import yaml |
from collections import OrderedDict |
from os import path as osp |
from basicsr.utils import set_random_seed |
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only |
def ordered_yaml():
    """Return a yaml Loader/Dumper pair that works with ``OrderedDict``.

    The C-accelerated ``CLoader``/``CDumper`` classes are used when they can
    be imported; otherwise the pure-Python ``Loader``/``Dumper`` are used.
    A representer for ``OrderedDict`` is registered on the Dumper class and a
    constructor for the default mapping tag is registered on the Loader
    class, so mappings are loaded as ``OrderedDict``.

    Returns:
        tuple: ``(Loader, Dumper)`` yaml classes with the handlers registered.
    """
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    def _represent_ordered(dumper, data):
        # Dump an OrderedDict as a plain mapping, keeping its item order.
        return dumper.represent_dict(data.items())

    def _construct_ordered(loader, node):
        # Build every yaml mapping as an OrderedDict from its key/value pairs.
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(
        yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _construct_ordered)
    return Loader, Dumper
def dict2str(opt, indent_level=1):
    """Format a (possibly nested) option dict as an indented string.

    Each nesting level is indented by two spaces per ``indent_level``. Nested
    dicts are rendered as ``key:[`` ... ``]`` with their items one level
    deeper. Keys are converted with ``str`` so that non-string keys (e.g.
    integers parsed from YAML) do not raise ``TypeError``.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Returns:
        str: Option string for printing; always starts with a newline.
    """
    indent = ' ' * (indent_level * 2)
    # Collect the pieces and join once instead of repeated string +=.
    parts = ['\n']
    for k, v in opt.items():
        key = str(k)
        if isinstance(v, dict):
            parts.append(f'{indent}{key}:[')
            parts.append(dict2str(v, indent_level + 1))
            parts.append(f'{indent}]\n')
        else:
            parts.append(f'{indent}{key}: {v!s}\n')
    return ''.join(parts)
def _postprocess_yml_value(value): |
if value == '~' or value.lower() == 'none': |
return None |
if value.lower() == 'true': |
return True |
elif value.lower() == 'false': |
return False |
if value.startswith('!!float'): |
return float(value.replace('!!float', '')) |
if value.isdigit(): |
return int(value) |
elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: |
return float(value) |
if value.startswith('['): |
return eval(value) |
return value |
def parse_options(root_path, SR, is_train=True):
    """Parse command-line arguments and load the option YAML file.

    Args:
        root_path (str): Root directory; the default option file and the
            ``experiments`` / ``results`` output folders are joined onto it.
        SR (str): Scale selector. ``'x4'`` and ``'x2'`` select the default
            option file ``options/test/test_RGT_<SR>.yml``. Any other value
            has no default, so ``-opt`` must be given on the command line.
        is_train (bool): Whether to set up training paths (``experiments``)
            or testing paths (``results``). Default: True.

    Returns:
        tuple: ``(opt, args)`` where ``opt`` is the loaded and post-processed
        option dict and ``args`` is the parsed ``argparse`` namespace.

    Raises:
        ValueError: If no option file can be determined, or if a
            ``--force_yml`` entry does not contain ``=``.
    """
    default_opt_files = {
        'x4': 'options/test/test_RGT_x4.yml',
        'x2': 'options/test/test_RGT_x2.yml',
    }
    # Unknown SR values get no default instead of leaving the name unbound.
    file_path = None
    if SR in default_opt_files:
        file_path = osp.join(root_path, default_opt_files[SR])

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default=file_path, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args()

    if args.opt is None:
        raise ValueError(
            f"No default option file for SR={SR!r}; expected one of "
            f"{sorted(default_opt_files)} or an explicit -opt path.")

    # NOTE(review): the Loader returned by ordered_yaml() is yaml's full
    # Loader/CLoader rather than a safe loader, so only load trusted files.
    with open(args.opt, mode='r') as f:
        opt = yaml.load(f, Loader=ordered_yaml()[0])

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed: pick one if the option file gives none, and offset it by
    # the rank so each process is seeded differently
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    # force-update options given as "key1:key2=value"
    if args.force_yml is not None:
        for entry in args.force_yml:
            # Split on the first '=' only so the value may itself contain '='.
            keys, sep, value = entry.partition('=')
            if not sep:
                raise ValueError(f"Invalid --force_yml entry {entry!r}; expected 'key1:key2=value'.")
            value = _postprocess_yml_value(value.strip())
            # Walk the nested dict directly instead of exec()-ing a
            # string-built assignment.
            *parent_keys, leaf_key = keys.strip().split(':')
            node = opt
            for key in parent_keys:
                node = node[key]
            node[leaf_key] = value

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug setting
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # the phase is the part of the dataset key before the first '_'
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # expand "~" in resume-state and pretrained-network paths
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # in debug mode, validate / log / checkpoint much more frequently
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args
@master_only
def copy_opt_file(opt_file, experiments_root):
    """Copy the option file into the experiment folder with a header.

    The copy keeps the option file's base name and gets a comment header
    prepended that records the generation time and the command line
    (``sys.argv``) of the current process.

    Args:
        opt_file (str): Path of the option file to copy.
        experiments_root (str): Directory that receives the copy.
    """
    import sys
    import time
    from shutil import copyfile

    cmd = ' '.join(sys.argv)
    target = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, target)

    with open(target, 'r+') as f:
        body = f.read()
        header = f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n'
        # Rewind and overwrite from the start: header first, then the
        # original content.
        f.seek(0)
        f.write(header + body)