|
|
|
|
|
|
|
|
|
import os, sys |
|
import argparse |
|
from pathlib import Path |
|
|
|
from omegaconf import OmegaConf |
|
from sampler import ResShiftSampler |
|
|
|
from utils.util_opts import str2bool |
|
from basicsr.utils.download_util import load_file_from_url |
|
|
|
# Step count per checkpoint key. It is the `_s<N>` tag that appears in the
# released checkpoint filenames below (presumably the number of diffusion
# sampling steps -- confirm against the sampler configs).
_STEP = {
    'v1': 15,
    'v2': 15,
    **dict.fromkeys(
        ('v3', 'bicsr', 'inpaint_imagenet', 'inpaint_face', 'faceir'),
        4,
    ),
}

# Download URLs for the VQGAN autoencoders and the ResShift checkpoints,
# keyed by realsr version or by task name.
_LINK = dict(
    vqgan='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/autoencoder_vq_f4.pth',
    vqgan_face256='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/celeba256_vq_f4_dim3_face.pth',
    vqgan_face512='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/ffhq512_vq_f8_dim8_face.pth',
    v1='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v1.pth',
    v2='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v2.pth',
    v3='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s4_v3.pth',
    bicsr='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_bicsrx4_s4.pth',
    inpaint_imagenet='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_imagenet_s4.pth',
    inpaint_face='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_face_s4.pth',
    faceir='https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_faceir_s4.pth',
)
|
|
|
def get_parser(*, argv=None, **parser_kwargs):
    """Define the ResShift inference CLI and parse its arguments.

    Despite the name, this returns the parsed ``argparse.Namespace``,
    not the parser itself.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.
        **parser_kwargs: Forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("-i", "--in_path", type=str, default="", help="Input path.")
    parser.add_argument("-o", "--out_path", type=str, default="./results", help="Output path.")
    parser.add_argument("--mask_path", type=str, default="", help="Mask path for inpainting.")
    parser.add_argument("--scale", type=int, default=4, help="Scale factor for SR.")
    parser.add_argument("--seed", type=int, default=12345, help="Random seed.")
    parser.add_argument("--bs", type=int, default=1, help="Batch size.")
    parser.add_argument(
        "-v",
        "--version",
        type=str,
        default="v1",
        choices=["v1", "v2", "v3"],
        help="Checkpoint version (used by the realsr task).",
    )
    parser.add_argument(
        "--chop_size",
        type=int,
        default=512,
        choices=[512, 256, 64],
        help="Chopping forward.",
    )
    parser.add_argument(
        "--chop_stride",
        type=int,
        default=-1,
        help="Chopping stride. A negative value selects a default based on --chop_size.",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="realsr",
        choices=['realsr', 'bicsr', 'inpaint_imagenet', 'inpaint_face', 'faceir'],
        help="Restoration task to run.",
    )
    return parser.parse_args(argv)
|
|
|
def _select_task_assets(args, ckpt_dir):
    """Pick the config and the checkpoint/VQGAN URLs and local paths for ``args.task``.

    Also checks that ``args.scale`` is valid for the task.

    Returns:
        (configs, ckpt_url, ckpt_path, vqgan_url, vqgan_path)

    Raises:
        ValueError: unknown realsr ``args.version``.
        AssertionError: ``args.scale`` is not supported by the task.
        TypeError: unknown ``args.task``.
    """
    if args.task == 'realsr':
        if args.version in ['v1', 'v2']:
            configs = OmegaConf.load('./configs/realsr_swinunet_realesrgan256.yaml')
        elif args.version == 'v3':
            configs = OmegaConf.load('./configs/realsr_swinunet_realesrgan256_journal.yaml')
        else:
            raise ValueError(f"Unexpected version type: {args.version}")
        assert args.scale == 4, 'We only support the 4x super-resolution now!'
        ckpt_url = _LINK[args.version]
        ckpt_path = ckpt_dir / f'resshift_{args.task}x{args.scale}_s{_STEP[args.version]}_{args.version}.pth'
        vqgan_url = _LINK['vqgan']
        vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
    elif args.task == 'bicsr':
        configs = OmegaConf.load('./configs/bicx4_swinunet_lpips.yaml')
        assert args.scale == 4, 'We only support the 4x super-resolution now!'
        ckpt_url = _LINK[args.task]
        ckpt_path = ckpt_dir / f'resshift_{args.task}x{args.scale}_s{_STEP[args.task]}.pth'
        vqgan_url = _LINK['vqgan']
        vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
    elif args.task == 'inpaint_imagenet':
        configs = OmegaConf.load('./configs/inpaint_lama256_imagenet.yaml')
        assert args.scale == 1, 'Please set scale equals 1 for image inpainting!'
        ckpt_url = _LINK[args.task]
        ckpt_path = ckpt_dir / f'resshift_{args.task}_s{_STEP[args.task]}.pth'
        vqgan_url = _LINK['vqgan']
        vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'
    elif args.task == 'inpaint_face':
        configs = OmegaConf.load('./configs/inpaint_lama256_face.yaml')
        assert args.scale == 1, 'Please set scale equals 1 for image inpainting!'
        ckpt_url = _LINK[args.task]
        ckpt_path = ckpt_dir / f'resshift_{args.task}_s{_STEP[args.task]}.pth'
        vqgan_url = _LINK['vqgan_face256']
        vqgan_path = ckpt_dir / 'celeba256_vq_f4_dim3_face.pth'
    elif args.task == 'faceir':
        configs = OmegaConf.load('./configs/faceir_gfpgan512_lpips.yaml')
        assert args.scale == 1, 'Please set scale equals 1 for face restoration!'
        ckpt_url = _LINK[args.task]
        ckpt_path = ckpt_dir / f'resshift_{args.task}_s{_STEP[args.task]}.pth'
        vqgan_url = _LINK['vqgan_face512']
        vqgan_path = ckpt_dir / 'ffhq512_vq_f8_dim8_face.pth'
    else:
        raise TypeError(f"Unexpected task type: {args.task}!")
    return configs, ckpt_url, ckpt_path, vqgan_url, vqgan_path


def _download_if_missing(url, path, model_dir):
    """Download ``url`` into ``model_dir`` as ``path.name`` unless ``path`` already exists."""
    if not path.exists():
        load_file_from_url(
            url=url,
            model_dir=model_dir,
            progress=True,
            file_name=path.name,
        )


def _resolve_chop_stride(args):
    """Return the chopping stride and rescale ``args.chop_size`` in place.

    Both values are multiplied by ``4 // args.scale``. A negative
    ``args.chop_stride`` selects a default stride of ``chop_size`` minus a
    size-dependent overlap.

    Raises:
        ValueError: ``args.chop_size`` has no default overlap.
    """
    factor = 4 // args.scale
    if args.chop_stride < 0:
        # Default overlap between neighbouring chops, per supported chop size.
        overlaps = {512: 64, 256: 32, 64: 16}
        if args.chop_size not in overlaps:
            raise ValueError("Chop size must be in [512, 256, 64]")
        chop_stride = (args.chop_size - overlaps[args.chop_size]) * factor
    else:
        chop_stride = args.chop_stride * factor
    args.chop_size *= factor
    return chop_stride


def get_configs(args):
    """Build the sampler config for ``args.task`` and compute the chopping stride.

    Side effects:
        * creates ``./weights`` and ``args.out_path`` when missing;
        * downloads the task checkpoint and its VQGAN autoencoder into
          ``./weights`` when the files are not already present;
        * multiplies ``args.chop_size`` in place by ``4 // args.scale``;
        * prints the resulting chop size/stride.

    Args:
        args: Namespace produced by ``get_parser``.

    Returns:
        (configs, chop_stride): the OmegaConf config with checkpoint paths and
        ``diffusion.params.sf`` filled in, and the chopping stride.

    Raises:
        AssertionError: ``args.scale`` is not supported by the task.
        ValueError: unknown realsr version or unsupported chop size.
        TypeError: unknown task.
    """
    ckpt_dir = Path('./weights')
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    configs, ckpt_url, ckpt_path, vqgan_url, vqgan_path = _select_task_assets(args, ckpt_dir)

    _download_if_missing(ckpt_url, ckpt_path, ckpt_dir)
    _download_if_missing(vqgan_url, vqgan_path, ckpt_dir)

    configs.model.ckpt_path = str(ckpt_path)
    configs.diffusion.params.sf = args.scale
    configs.autoencoder.ckpt_path = str(vqgan_path)

    # exist_ok avoids the check-then-create race of exists() followed by mkdir().
    Path(args.out_path).mkdir(parents=True, exist_ok=True)

    chop_stride = _resolve_chop_stride(args)
    print(f"Chopping size/stride: {args.chop_size}/{chop_stride}")

    return configs, chop_stride
|
|
|
def main():
    """Parse CLI arguments, build a ResShiftSampler, and run inference on ``args.in_path``."""
    args = get_parser()

    # Validate the inpainting mask before get_configs downloads weights and the
    # sampler is built, so a missing --mask_path fails fast instead of after
    # the expensive setup.
    if args.task.startswith('inpaint'):
        assert args.mask_path, 'Please input the mask path for inpainting!'
        mask_path = args.mask_path
    else:
        mask_path = None

    # Note: get_configs rescales args.chop_size in place.
    configs, chop_stride = get_configs(args)

    resshift_sampler = ResShiftSampler(
        configs,
        sf=args.scale,
        chop_size=args.chop_size,
        chop_stride=chop_stride,
        chop_bs=1,
        use_amp=True,
        seed=args.seed,
        padding_offset=configs.model.params.get('lq_size', 64),
    )

    resshift_sampler.inference(
        args.in_path,
        args.out_path,
        mask_path=mask_path,
        bs=args.bs,
        noise_repeat=False,
    )
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|