# videoenhancer / frame_esrgan.py
# Uploaded by peterkros in commit 9df91a5 ("Upload 12 files", verified).
import torch
import torchvision
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
import cv2
import argparse
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import image_slicer
from image_slicer import join
from PIL import Image
import numpy as np
from tqdm import tqdm
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV-style array in BGR channel order.

    Args:
        img: PIL image holding RGB or RGBA data. A single-channel image is not
            supported: the RGB->BGR conversion needs 3 or 4 channels, so
            ``cv2.cvtColor`` raises for it.

    Returns:
        A 3-channel ``np.ndarray`` with the red and blue channels swapped
        relative to ``img`` (the layout ``cv2.imwrite`` expects). Any alpha
        channel is dropped by the conversion.
    """
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
def upscale(model_path, im_path):
    """Upscale the image at ``im_path`` by 4x with a Real-ESRGAN RRDBNet model.

    Args:
        model_path: Path to the Real-ESRGAN weights file. The network is built
            as an RRDBNet with 6 RRDB blocks and scale 4, so the weights must
            match that architecture.
        im_path: Path of the image to upscale. It is read with
            ``cv2.IMREAD_UNCHANGED``, so an alpha channel or higher bit depth
            is passed through to the upsampler.

    Returns:
        The upscaled image as an ``np.ndarray`` in OpenCV channel order, with
        4x the input width and height.

    Raises:
        OSError: if OpenCV cannot read ``im_path`` (missing or unreadable file).
    """
    # cv2.imread signals failure by returning None instead of raising.
    # Checking before building the model also avoids loading weights for nothing.
    img = cv2.imread(im_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError(f'cv2.imread could not read image: {im_path!r}')

    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
    # tile=0 means RealESRGANer processes the whole image in one pass (no internal tiling).
    upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=False)
    output, _ = upsampler.enhance(img, outscale=4)
    return output
def upscale_slice(model_path, image, slice):
    """Upscale an image by 4x tile by tile, to limit peak GPU memory.

    The image is cut into ``slice`` tiles with ``image_slicer``. Each tile is
    upscaled by Real-ESRGAN on its own, and the tiles are stitched back
    together at 4x their original coordinates.

    Args:
        model_path: Path to Real-ESRGAN weights for an RRDBNet with 6 blocks
            and scale 4.
        image: Path of the image file to upscale.
        slice: Number of tiles to cut the image into (passed to
            ``image_slicer.slice``). The name shadows the builtin but is kept
            for compatibility with keyword callers.

    Returns:
        The stitched, upscaled image as a 3-channel BGR ``np.ndarray``.
    """
    # PIL opens files lazily and keeps the handle open; only the size is
    # needed here, so close it straight away.
    with Image.open(image) as src:
        width, height = src.size

    tiles = image_slicer.slice(image, slice, save=False)
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
    upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=False)

    for tile in tiles:
        # PIL tiles are RGB, but RealESRGANer.enhance takes and returns
        # OpenCV-ordered (BGR) arrays. Convert in both directions so the network
        # sees channels in the order it expects.
        # convert('RGB') also normalises palette, greyscale and RGBA tiles. Alpha
        # would not survive anyway: the result is turned into 3-channel BGR below.
        tile_bgr = cv2.cvtColor(np.array(tile.image.convert('RGB')), cv2.COLOR_RGB2BGR)
        out_bgr, _ = upsampler.enhance(tile_bgr, outscale=4)
        tile.image = Image.fromarray(cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB))
        # The tile grew by 4x, so its paste position in the canvas grows by 4x too.
        tile.coords = (tile.coords[0] * 4, tile.coords[1] * 4)

    return convert_from_image_to_cv2(join(tiles, width=width * 4, height=height * 4))
def _for_matplotlib(img):
    """Return ``img`` (OpenCV channel order) reordered to the RGB(A) order matplotlib expects."""
    if img.ndim == 3 and img.shape[2] == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if img.ndim == 3 and img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
    # Single-channel image: there are no colour channels to reorder.
    return img


def main():
    """Command-line entry point: upscale one image with Real-ESRGAN and save it.

    Parses the CLI arguments, upscales ``--input`` (tile by tile when
    ``--slice`` is given), optionally shows a before/after preview, optionally
    resizes the result, then writes it to ``--output`` with OpenCV.
    Exits with a non-zero status on invalid arguments or when the output
    cannot be written.
    """
    parser = argparse.ArgumentParser()
    # argparse enforces the REQUIRED options itself: a missing one prints usage
    # and exits with status 2 instead of falling through with status 0.
    parser.add_argument('-m', '--model_path', type=str, required=True, help='REQUIRED: specify path of the model being used')
    parser.add_argument('-i', '--input', type=str, required=True, help='REQUIRED: specify path of the image you want to upscale')
    parser.add_argument('-o', '--output', type=str, required=True, help='REQUIRED: specify path where you want to save image')
    parser.add_argument('-v', '--visualize', action='store_true', help='OPTIONAL: add this to see how image looks before and after upscale')
    parser.add_argument('-s', '--slice', nargs='?', type=int, const=4, help='OPTIONAL: specify weather to split frames, recommended to use to help with VRAM unless you got a fucken quadro or something')
    parser.add_argument('-r', '--resize', nargs='?', type=str, const='1920x1080', help="OPTIONAL: specify whether to resize image to a specific resolution. Specify with widthxheight, for example 1920x1080")
    args = parser.parse_args()

    # Validate --resize before the (slow) upscale, so a typo fails immediately.
    size = None
    if args.resize:
        try:
            size = tuple(int(part) for part in args.resize.split('x'))
        except ValueError:
            size = ()
        if len(size) != 2 or min(size) <= 0:
            parser.error(f'argument -r/--resize: expected WIDTHxHEIGHT, got {args.resize!r}')

    if args.slice:
        output = upscale_slice(args.model_path, args.input, args.slice)
    else:
        output = upscale(args.model_path, args.input)

    if args.visualize:
        plt.imshow(mpimg.imread(args.input))
        plt.show()
        # ``output`` is in OpenCV (BGR) order; matplotlib would show swapped colours.
        # cmap only affects single-channel images.
        plt.imshow(_for_matplotlib(output), cmap='gray')
        plt.show()

    if size is not None:
        # cv2.resize takes dsize as (width, height), matching the WIDTHxHEIGHT syntax.
        output = cv2.resize(output, size)

    # cv2.imwrite reports failure (bad extension, missing directory) via its return value.
    if not cv2.imwrite(args.output, output):
        raise SystemExit(f'Error: could not write image to {args.output!r}')


if __name__ == '__main__':
    main()
# tiles = image_slicer.slice('tmp/{}/original/{}'.format(folder_name, i), slice, save=False)
# print(tiles)
# for tile in tiles:
# up = frame_esrgan.upscale_slice(args.model_path, np.array(tile.image))
# tile.image = Image.fromarray(up, 'RGB')
# out = join(tiles)
# out.save('tmp/{}/upscaled/{}'.format(folder_name, i.replace('jpg', 'png')))