import io
import math
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from PIL import Image
from realesrgan import RealESRGANer  # pyright: ignore

import internals.util.image as ImageUtil
from internals.util.commons import download_image
from internals.util.config import get_root_dir
from models.ultrasharp.model import Ultrasharp


class Upscaler:
    """Upscale images with Real-ESRGAN / 4x-UltraSharp, optionally restoring faces with GFPGAN.

    Model checkpoints are fetched lazily into ``~/.cache/realesrgan`` the first
    time an upscale method is called (see :meth:`load`).
    """

    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
    __model_gfpgan_url = (
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
    )
    __model_4x_ultrasharp_url = (
        "https://comic-assets.s3.ap-south-1.amazonaws.com/models/4x-UltraSharp.pth"
    )

    __loaded = False

    def load(self):
        """Ensure every model checkpoint is present locally and remember its path.

        Checkpoints that already exist in ``~/.cache/realesrgan`` are reused.
        Missing ones are downloaded. After the first successful call this
        method returns immediately.
        """
        if self.__loaded:
            return

        download_dir = Path.home() / ".cache" / "realesrgan"
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )
        self.__model_path_gfpgan = self.__preload_model(
            self.__model_gfpgan_url, download_dir
        )
        self.__model_path_4x_ultrasharp = self.__preload_model(
            self.__model_4x_ultrasharp_url, download_dir
        )
        self.__loaded = True

    def upscale(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: Optional[int] = None,
    ) -> bytes:
        """Upscale a general image with the realesr-general-x4v3 model.

        Args:
            image: A PIL image, or a URL string that is fetched via ``download_image``.
            width: Target width; used only when ``resize_dimension`` is not provided.
            height: Target height; used only when ``resize_dimension`` is not provided.
            face_enhance: Run GFPGAN face restoration on top of the upscaler.
            resize_dimension: Target size of the longest side. If not provided,
                the LARGER of ``width`` and ``height`` is used.

        Returns:
            PNG-encoded bytes of the upscaled image.
        """
        self.load()
        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=4,
            act_type="prelu",
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path,
            model,
        )

    def upscale_anime(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: Optional[int] = None,
    ) -> bytes:
        """Upscale an anime-style image with the RealESRGAN_x4plus_anime_6B model.

        Args:
            image: A PIL image, or a URL string that is fetched via ``download_image``.
            width: Target width; used only when ``resize_dimension`` is not provided.
            height: Target height; used only when ``resize_dimension`` is not provided.
            face_enhance: Run GFPGAN face restoration on top of the upscaler.
            resize_dimension: Target size of the longest side. If not provided,
                the LARGER of ``width`` and ``height`` is used.

        Returns:
            PNG-encoded bytes of the upscaled image.
        """
        self.load()
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path_anime,
            model,
        )

    def __preload_model(self, url: str, download_dir: Path) -> str:
        """Return the local path of the checkpoint at *url*, downloading it if absent.

        The local file name is the last path segment of *url*.
        """
        local_path = download_dir / url.split("/")[-1]
        if local_path.exists():
            return str(local_path)
        return load_file_from_url(
            url=url,
            model_dir=str(download_dir),
            progress=True,
            file_name=None,
        )

    @staticmethod
    def __to_cv2_image(image: Image.Image) -> np.ndarray:
        """Convert a PIL image to an OpenCV array by round-tripping through a PNG file.

        A unique temporary file under ``~/.cache`` is used and removed
        afterwards, so concurrent calls (e.g. several worker processes) cannot
        overwrite each other's input. The image is read with
        ``IMREAD_UNCHANGED``, so any alpha channel is kept.

        Raises:
            ValueError: if OpenCV cannot decode the written file.
        """
        cache_dir = Path.home() / ".cache"
        cache_dir.mkdir(parents=True, exist_ok=True)
        fd, in_path = tempfile.mkstemp(
            prefix="input_upscale_", suffix=".png", dir=str(cache_dir)
        )
        os.close(fd)  # PIL opens the path itself; we only needed a unique name.
        try:
            image.save(in_path)
            input_image = cv2.imread(in_path, cv2.IMREAD_UNCHANGED)
        finally:
            os.remove(in_path)
        if input_image is None:
            raise ValueError("Could not decode the input image for upscaling")
        return input_image

    def __internal_upscale(
        self,
        image: Union[str, Image.Image],
        resize_dimension: Optional[int],
        face_enhance: bool,
        width: int,
        height: int,
        model_path: str,
        model,
    ) -> bytes:
        """Shared upscaling pipeline used by :meth:`upscale` and :meth:`upscale_anime`.

        The integer scale factor is ``floor(target / longest_side)``, clamped
        to at least 2. When the factor is exactly 4, the 4x-UltraSharp model is
        used. Otherwise RealESRGANer runs with *model* / *model_path*.
        """
        if isinstance(image, str):
            image = download_image(image, mode="RGBA")

        input_image = self.__to_cv2_image(image)

        dimension = max(input_image.shape[0], input_image.shape[1])
        if not resize_dimension:
            resize_dimension = max(width, height)
        scale = max(math.floor(resize_dimension / dimension), 2)
        print("Upscaling by: ", scale)

        # The model wrappers below run with ~/.cache as the working directory.
        # NOTE(review): presumably so any relative-path files they create land
        # there; confirm. The finally block restores the CWD even on failure,
        # because the CWD is process-wide state.
        os.chdir(str(Path.home() / ".cache"))
        try:
            if scale == 4:
                print("Using 4x-Ultrasharp")
                upsampler = Ultrasharp(
                    model_path=self.__model_path_4x_ultrasharp,
                    tile=320,
                    tile_pad=10,
                )
            else:
                print("Using RealESRGANer")
                upsampler = RealESRGANer(
                    scale=4,
                    model_path=model_path,
                    model=model,
                    half=False,
                    gpu_id="0",
                    tile=320,
                    tile_pad=10,
                    pre_pad=0,
                )

            if face_enhance:
                # Only build the GFPGAN model when it is actually needed.
                face_enhancer = GFPGANer(
                    model_path=self.__model_path_gfpgan,
                    upscale=scale,
                    arch="clean",
                    channel_multiplier=2,
                    bg_upsampler=upsampler,
                )
                _, _, output = face_enhancer.enhance(
                    input_image, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(input_image, outscale=scale)
        finally:
            os.chdir(get_root_dir())

        # NOTE(review): writes a copy of the result into the root dir; this
        # looks like a debugging artifact. It is kept to preserve behavior.
        cv2.imwrite("out.png", output)

        ok, encoded = cv2.imencode(".png", output)
        if not ok:
            raise RuntimeError("Failed to PNG-encode the upscaled image")
        return encoded.tobytes()

    @staticmethod
    def to_pil(buffer: bytes, mode="RGB") -> Image.Image:
        """Decode image *buffer* (e.g. the PNG bytes returned above) into a PIL image in *mode*."""
        return Image.open(io.BytesIO(buffer)).convert(mode)