|
from easydict import EasyDict |
|
|
|
|
|
# Global default configuration. Any key defined here may be overridden with
# load_config(); keys that are not defined here are rejected as unknown.
CONFIG = EasyDict({
    # Flipped to False by load_config() as soon as a custom config is applied.
    "is_default": True,
    "version": "baseline",
    "phase": "train",
    # Run-mode toggles (presumably distributed training and Weights & Biases
    # logging -- confirm where these flags are consumed).
    "dist": False,
    "wandb": False,
    # Process / device placement defaults (single-process values; meaning
    # assumed from the names -- verify against the launcher).
    "local_rank": 0,
    "gpu": 0,
    "world_size": 1,
})
|
|
|
|
|
# Model-related defaults. Architecture components are selected by name below;
# presumably the names are resolved against a registry elsewhere -- verify.
CONFIG.model = EasyDict({
    "freeze_seg": True,
    "multi_scale": False,
    "imagenet_pretrain": True,
    # NOTE(review): absolute path on one specific machine; override it via
    # load_config() when running anywhere else.
    "imagenet_pretrain_path": "/home/liyaoyi/Source/python/attentionMatting/pretrain/model_best_resnet34_En_nomixup.pth",
    "batch_size": 16,
    # Input channel counts for the mask and trimap inputs.
    "mask_channel": 1,
    "trimap_channel": 3,
    # Widths used by the self-refinement / self-mask steps (units not visible
    # here -- presumably pixels; confirm).
    "self_refine_width1": 30,
    "self_refine_width2": 15,
    "self_mask_width": 10,
    "arch": EasyDict({
        "encoder": "res_shortcut_encoder_29",
        "decoder": "res_shortcut_decoder_22",
        "m2m": "conv_baseline",
        "seg": "maskrcnn",
        # None means no discriminator is configured by default.
        "discriminator": None,
    }),
})
|
|
|
|
|
|
|
# Data-pipeline defaults. Dataset locations are all left unset (None) here;
# presumably they must be supplied through a custom config -- confirm.
CONFIG.data = EasyDict({
    "cutmask_prob": 0,
    "workers": 0,
    "pha_ratio": 0.5,
    **dict.fromkeys((
        "train_fg", "train_alpha", "train_bg",
        "test_merged", "test_alpha", "test_trimap",
        "imagematte_fg", "imagematte_pha",
        "d646_fg", "d646_pha",
        "aim_fg", "aim_pha",
        "human2k_fg", "human2k_pha",
        "am2k_fg", "am2k_pha",
        "coco_bg",
        "bg20k_bg",
        "rim_pha", "rim_img",
        "spd_pha", "spd_img",
    )),
    "crop_size": 1024,
    # Augmentation switches.
    "real_world_aug": False,
    "augmentation": True,
    "random_interp": True,
})
|
|
|
|
|
# Evaluation-benchmark locations.
# NOTE(review): every default is an absolute path under one user's home
# directory, so these must be overridden via load_config() on other hosts.
CONFIG.benchmark = EasyDict({
    "him2k_img": '/home/jiachen.li/data/HIM2K/images/natural',
    "him2k_alpha": '/home/jiachen.li/data/HIM2K/alphas/natural',
    "him2k_comp_img": '/home/jiachen.li/data/HIM2K/images/comp',
    "him2k_comp_alpha": '/home/jiachen.li/data/HIM2K/alphas/comp',
    "rwp636_img": '/home/jiachen.li/data/RealWorldPortrait-636/image',
    "rwp636_alpha": '/home/jiachen.li/data/RealWorldPortrait-636/alpha',
    "ppm100_img": '/home/jiachen.li/data/PPM-100/image',
    "ppm100_alpha": '/home/jiachen.li/data/PPM-100/matte',
    "am2k_img": '/home/jiachen.li/data/AM2k/validation/original',
    "am2k_alpha": '/home/jiachen.li/data/AM2k/validation/mask',
    "rw100_img": '/home/jiachen.li/data/RefMatte_RW_100/image_all',
    "rw100_alpha": '/home/jiachen.li/data/RefMatte_RW_100/mask',
    "rw100_text": '/home/jiachen.li/data/RefMatte_RW_100/refmatte_rw100_label.json',
    "rw100_index": '/home/jiachen.li/data/RefMatte_RW_100/eval_index_expression.json',
    "vm_img": '/home/jiachen.li/data/videomatte_512x288',
})
|
|
|
|
|
# Training-loop defaults.
CONFIG.train = EasyDict({
    # Step counts for the schedule and for validation.
    "total_step": 100000,
    "warmup_step": 5000,
    "val_step": 1000,
    # Learning rate (presumably for the main network "G"; confirm against
    # the optimizer setup).
    "G_lr": 1e-3,
    # Optimizer momentum terms (presumably Adam betas -- confirm).
    "beta1": 0.5,
    "beta2": 0.999,
    # Relative weights of the loss terms.
    "rec_weight": 1,
    "comp_weight": 1,
    "lap_weight": 1,
    "clip_grad": True,
    # None means training does not resume from a checkpoint by default.
    "resume_checkpoint": None,
    "reset_lr": False,
})
|
|
|
|
|
|
|
# Logging / checkpointing defaults: output locations and step intervals.
CONFIG.log = EasyDict({
    "tensorboard_path": "./logs/tensorboard",
    "tensorboard_step": 100,
    "tensorboard_image_step": 500,
    "logging_path": "./logs/stdout",
    "logging_step": 10,
    "logging_level": "DEBUG",
    "checkpoint_path": "./checkpoints",
    "checkpoint_step": 10000,
})
|
|
|
|
|
def load_config(custom_config, default_config=CONFIG, prefix="CONFIG"):
    """Recursively overwrite ``default_config`` in place with ``custom_config``.

    Every key of ``custom_config`` must already exist in ``default_config``.
    When both sides hold a dict for a key, the two are merged recursively.
    Otherwise the custom value replaces the default outright. If
    ``default_config`` has an ``is_default`` key, it is set to False on entry.
    This marks the config as customised, even if a later key fails validation
    and leaves the config partially updated.

    :param custom_config: mapping of overrides, e.g. parsed from
        config/config.toml.
    :param default_config: mapping updated in place; defaults to the
        module-level ``CONFIG``.
    :param prefix: dotted path of ``default_config``; only used to build
        readable key names for error messages.
    :return: None
    :raises NotImplementedError: if ``custom_config`` contains a key that
        ``default_config`` does not define.
    :raises ValueError: if a key holds a dict on one side but not the other.
    """
    if "is_default" in default_config:
        # Item assignment (rather than attribute access) also works when a
        # plain dict is passed as ``default_config``.
        default_config["is_default"] = False

    for key, value in custom_config.items():
        full_key = ".".join([prefix, key])
        if key not in default_config:
            raise NotImplementedError("Unknown config key: {}".format(full_key))

        default_value = default_config[key]
        if isinstance(value, dict):
            if not isinstance(default_value, dict):
                # Report the type the default actually holds; the override's
                # type is always dict in this branch.
                raise ValueError("{}: Expected {}, got dict instead.".format(full_key, type(default_value)))
            load_config(default_config=default_value,
                        custom_config=value,
                        prefix=full_key)
        elif isinstance(default_value, dict):
            raise ValueError("{}: Expected dict, got {} instead.".format(full_key, type(value)))
        else:
            default_config[key] = value
|
|
|
|
|
|