|
|
|
|
|
|
|
|
|
import argparse |
|
import gradio as gr |
|
from pathlib import Path |
|
|
|
from omegaconf import OmegaConf |
|
from sampler import ResShiftSampler |
|
|
|
from utils import util_image |
|
from basicsr.utils.download_util import load_file_from_url |
|
|
|
_STEP = { |
|
'v1': 15, |
|
'v2': 15, |
|
'v3': 4, |
|
'bicsr': 4, |
|
'inpaint_imagenet': 4, |
|
'inpaint_face': 4, |
|
'faceir': 4, |
|
} |
|
_LINK = { |
|
'vqgan': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/autoencoder_vq_f4.pth', |
|
'vqgan_face256': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/celeba256_vq_f4_dim3_face.pth', |
|
'vqgan_face512': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/ffhq512_vq_f8_dim8_face.pth', |
|
'v1': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v1.pth', |
|
'v2': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s15_v2.pth', |
|
'v3': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_realsrx4_s4_v3.pth', |
|
'bicsr': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_bicsrx4_s4.pth', |
|
'inpaint_imagenet': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_imagenet_s4.pth', |
|
'inpaint_face': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_inpainting_face_s4.pth', |
|
'faceir': 'https://github.com/zsyOAOA/ResShift/releases/download/v2.0/resshift_faceir_s4.pth', |
|
} |
|
|
|
def _download_if_missing(url, path):
    """Download ``url`` to ``path`` (inside ``path.parent``) unless it already exists."""
    if not path.exists():
        load_file_from_url(
            url=url,
            model_dir=path.parent,
            progress=True,
            file_name=path.name,
        )


def get_configs(task='realsr', version='v3', scale=4):
    """Load the ResShift config for ``task`` and make sure its weights are on disk.

    Missing checkpoints (the ResShift model and the shared VQGAN autoencoder)
    are downloaded into ``./weights`` before the config is returned.

    Args:
        task: 'realsr' (real-world SR) or 'bicsr' (bicubic SR).
        version: model version for 'realsr' ('v1', 'v2' or 'v3'); it is
            ignored for 'bicsr'.
        scale: super-resolution factor; only 4 is supported.

    Returns:
        The OmegaConf config with ``model.ckpt_path``, ``diffusion.params.sf``
        and ``autoencoder.ckpt_path`` filled in.

    Raises:
        ValueError: unknown ``version`` for 'realsr', or ``scale != 4``.
        TypeError: unknown ``task``.
    """
    ckpt_dir = Path('./weights')
    # exist_ok avoids a FileExistsError race when several queued requests
    # reach this point at the same time.
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    if task == 'realsr':
        if version in ('v1', 'v2'):
            config_path = './configs/realsr_swinunet_realesrgan256.yaml'
        elif version == 'v3':
            config_path = './configs/realsr_swinunet_realesrgan256_journal.yaml'
        else:
            raise ValueError(f"Unexpected version type: {version}")
        ckpt_key = version
        ckpt_name = f'resshift_{task}x{scale}_s{_STEP[version]}_{version}.pth'
    elif task == 'bicsr':
        config_path = './configs/bicx4_swinunet_lpips.yaml'
        ckpt_key = task
        ckpt_name = f'resshift_{task}x{scale}_s{_STEP[task]}.pth'
    else:
        raise TypeError(f"Unexpected task type: {task}!")

    # An explicit raise (not assert) so the check survives `python -O`.
    if scale != 4:
        raise ValueError('We only support the 4x super-resolution now!')

    configs = OmegaConf.load(config_path)

    ckpt_path = ckpt_dir / ckpt_name
    # Both supported tasks use the same f4 VQGAN autoencoder.
    vqgan_path = ckpt_dir / 'autoencoder_vq_f4.pth'

    _download_if_missing(_LINK[ckpt_key], ckpt_path)
    _download_if_missing(_LINK['vqgan'], vqgan_path)

    configs.model.ckpt_path = str(ckpt_path)
    configs.diffusion.params.sf = scale
    configs.autoencoder.ckpt_path = str(vqgan_path)

    return configs
|
|
|
def predict(in_path, task='realsr', seed=12345, scale=4, version='v3'):
    """Restore one image with ResShift and return it along with its saved path.

    Args:
        in_path: filesystem path of the low-quality input image (the Gradio
            input uses ``gr.Image(type="filepath")``).
        task: 'realsr' or 'bicsr'; forwarded to get_configs.
        seed: random seed forwarded to the sampler.
        scale: super-resolution factor; get_configs only accepts 4.
        version: 'realsr' model version ('v1', 'v2' or 'v3').

    Returns:
        ``(im_sr, out_path)``: the restored image read back as an RGB uint8
        array, and the path (str) of the PNG in ``restored_output/``.

    Raises:
        gr.Error: no input image was provided.
        RuntimeError: the sampler did not produce the expected output file.
    """
    # Gradio passes None when the user submits without uploading an image.
    if not in_path:
        raise gr.Error('Please upload an input image first.')

    configs = get_configs(task, version, scale)
    # chop_* presumably configure the sampler's chunked processing of large
    # inputs; see ResShiftSampler for their exact semantics.
    resshift_sampler = ResShiftSampler(
        configs,
        sf=scale,
        chop_size=256,
        chop_stride=224,
        chop_bs=1,
        use_amp=True,
        seed=seed,
        padding_offset=configs.model.params.get('lq_size', 64),
    )

    out_dir = Path('restored_output')
    # exist_ok: requests can run concurrently, so a separate exists() check
    # followed by mkdir() can race and raise FileExistsError.
    out_dir.mkdir(parents=True, exist_ok=True)

    resshift_sampler.inference(
        in_path,
        out_dir,
        mask_path=None,
        bs=1,
        noise_repeat=False,
    )

    # The sampler is expected to write `<input stem>.png` into out_dir.
    out_path = out_dir / f"{Path(in_path).stem}.png"
    # An explicit raise (not assert) so the check survives `python -O`.
    if not out_path.exists():
        raise RuntimeError('Super-resolution failed!')

    im_sr = util_image.imread(out_path, chn="rgb", dtype="uint8")

    return im_sr, str(out_path)
|
|
|
# Page title shown at the top of the Gradio demo.
title = "ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting"

# HTML/Markdown blurb rendered under the title.
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/zsyOAOA/ResShift' target='_blank'><b>ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting</b></a>.<br>
🔥 ResShift is an efficient diffusion model designed for image super-resolution or restoration.<br>
"""

# Markdown footer rendered below the demo: repo link, citation, license, contact.
# NOTE(review): the contact address below reads "[email protected]", which looks
# like an email-obfuscation artifact; restore the real address before publishing.
article = r"""
If ResShift is helpful for your work, please help to ⭐ the <a href='https://github.com/zsyOAOA/ResShift' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/zsyOAOA/ResShift?affiliations=OWNER&color=green&style=social)](https://github.com/zsyOAOA/ResShift)

---
If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{yue2023resshift,
  title={ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting},
  author={Yue, Zongsheng and Wang, Jianyi and Loy, Chen Change},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year={2023},
  volume = {36},
  pages = {13294--13307},
}
```

📋 **License**

This project is licensed under <a rel="license" href="https://github.com/zsyOAOA/ResShift/blob/master/LICENSE">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.

📧 **Contact**

If you have any questions, please feel free to contact me via <b>[email protected]</b>.
![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/ResShift)
"""
|
# Gradio UI: the three inputs map positionally onto predict(in_path, task, seed);
# the two outputs receive predict's (image array, output file path) tuple.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath", label="Input: Low Quality Image"),
        gr.Dropdown(
            choices=["realsr", "bicsr"],
            value="realsr",
            label="Task",
        ),
        gr.Number(value=12345, precision=0, label="Random seed"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Output: High Quality Image"),
        gr.File(label="Download the output"),
    ],
    title=title,
    description=description,
    article=article,
    # Each example row is [input path, task, seed], matching `inputs` above.
    examples=[
        ['./testdata/RealSet65/0030.jpg', "realsr", 12345],
        ['./testdata/RealSet65/dog2.png', "realsr", 12345],
        ['./testdata/RealSet65/bears.jpg', "realsr", 12345],
        ['./testdata/RealSet65/oldphoto6.png', "realsr", 12345],
        ['./testdata/Bicubicx4/lq_matlab/ILSVRC2012_val_00000067.png', "bicsr", 12345],
        ['./testdata/Bicubicx4/lq_matlab/ILSVRC2012_val_00016898.png', "bicsr", 12345],
    ],
)

# NOTE(review): `concurrency_count` is a Gradio 3.x argument of `queue()`; Gradio 4
# replaced it with `default_concurrency_limit`. Confirm the pinned gradio version.
demo.queue(concurrency_count=4)

demo.launch(share=True)
|
|
|
|