super-image / app.py
Eugene Siow
Add pre-download and setup of models first, and DRLN and EDSR models. (41506bf)
import torch
import gradio as gr
from random import randint
from pathlib import Path
from super_image import ImageLoader, EdsrModel, MsrnModel, MdsrModel, AwsrnModel, A2nModel, CarnModel, PanModel, \
    HanModel, DrlnModel, RcanModel

title = "super-image"
description = "State of the Art Image Super-Resolution Models."
article = "<p style='text-align: center'><a href='https://github.com/eugenesiow/super-image'>Github Repo</a>" \
          "| <a href='https://eugenesiow.github.io/super-image/'>Documentation</a> " \
          "| <a href='https://github.com/eugenesiow/super-image#scale-x2'>Models</a></p>"


def get_model(model_name, scale):
    # Map the dropdown selection to the matching pretrained super_image model.
    # Anything unrecognized (e.g. 'EDSR-base') falls through to the EDSR-base default.
    if model_name == 'EDSR':
        model = EdsrModel.from_pretrained('eugenesiow/edsr', scale=scale)
    elif model_name == 'MSRN':
        model = MsrnModel.from_pretrained('eugenesiow/msrn', scale=scale)
    elif model_name == 'MDSR':
        model = MdsrModel.from_pretrained('eugenesiow/mdsr', scale=scale)
    elif model_name == 'AWSRN-BAM':
        model = AwsrnModel.from_pretrained('eugenesiow/awsrn-bam', scale=scale)
    elif model_name == 'A2N':
        model = A2nModel.from_pretrained('eugenesiow/a2n', scale=scale)
    elif model_name == 'CARN':
        model = CarnModel.from_pretrained('eugenesiow/carn', scale=scale)
    elif model_name == 'PAN':
        model = PanModel.from_pretrained('eugenesiow/pan', scale=scale)
    elif model_name == 'HAN':
        model = HanModel.from_pretrained('eugenesiow/han', scale=scale)
    elif model_name == 'DRLN':
        model = DrlnModel.from_pretrained('eugenesiow/drln', scale=scale)
    elif model_name == 'RCAN':
        model = RcanModel.from_pretrained('eugenesiow/rcan', scale=scale)
    else:
        model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale)
    return model


def inference(img, scale_str, model_name):
    # Upscale the input image with the selected model and scale factor,
    # saving the result to a uniquely named file under ./tmp/.
    _id = randint(1, 1000)
    output_dir = Path('./tmp/')
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / ('output_image' + str(_id) + '.jpg')
    scale = int(scale_str.replace('x', ''))
    model = get_model(model_name, scale)
    inputs = ImageLoader.load_image(img)
    preds = model(inputs)
    output_file_str = str(output_file.resolve())
    ImageLoader.save_image(preds, output_file_str)
    return output_file_str


# Download sample images used in the Gradio examples.
torch.hub.download_url_to_file('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/baby_mini_d3_gaussian.bmp',
                               'baby.bmp')
torch.hub.download_url_to_file('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/woman_mini_d3_gaussian.bmp',
                               'woman.bmp')

models = ['EDSR-base', 'DRLN', 'EDSR', 'MSRN', 'MDSR', 'AWSRN-BAM', 'A2N', 'CARN', 'PAN']
scales = [2, 3, 4]

# Pre-download and cache the weights for every model/scale combination so the
# first request for each model does not block on a download.
for model_name in models:
    for scale in scales:
        get_model(model_name, scale)

gr.Interface(
    inference,
    [
        gr.inputs.Image(type="pil", label="Input"),
        gr.inputs.Radio(["x2", "x3", "x4"], label='scale'),
        gr.inputs.Dropdown(choices=models, label='Model')
    ],
    gr.outputs.Image(type="file", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ['baby.bmp', 'x2', 'EDSR-base'],
        ['woman.bmp', 'x3', 'DRLN']
    ],
    enable_queue=True
).launch(debug=False)
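
# A minimal sketch of a quick local check that calls inference() directly,
# bypassing the Gradio UI. It assumes the 'baby.bmp' example downloaded above
# is present in the working directory; it is kept commented out so the
# deployed Space only runs the interface.
#
# from PIL import Image
# upscaled_path = inference(Image.open('baby.bmp'), 'x2', 'EDSR-base')
# print(upscaled_path)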