|
import io |
|
import math |
|
import os |
|
from pathlib import Path |
|
from typing import Optional, Union |
|
|
|
import cv2 |
|
import numpy as np |
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
from basicsr.archs.srvgg_arch import SRVGGNetCompact |
|
from basicsr.utils.download_util import load_file_from_url |
|
from gfpgan import GFPGANer |
|
from PIL import Image |
|
from realesrgan import RealESRGANer |
|
|
|
import internals.util.image as ImageUtil |
|
from internals.util.commons import download_image |
|
from internals.util.config import get_root_dir |
|
from models.ultrasharp.model import Ultrasharp |
|
|
|
|
|
class Upscaler:
    """Upscale images with Real-ESRGAN or 4x-UltraSharp, optionally via GFPGAN.

    Model weights are fetched lazily (on the first ``load()`` call) into
    ``~/.cache/realesrgan`` and reused on later calls.
    """

    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
    __model_gfpgan_url = (
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
    )
    __model_4x_ultrasharp_url = (
        "https://comic-assets.s3.ap-south-1.amazonaws.com/models/4x-UltraSharp.pth"
    )

    # Class-level default; ``load()`` shadows it with an instance attribute.
    __loaded = False

    def load(self):
        """Make sure all model weight files exist locally and remember their paths.

        Does nothing if this instance has already been loaded.
        """
        if self.__loaded:
            return

        download_dir = Path.home() / ".cache" / "realesrgan"
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )
        self.__model_path_gfpgan = self.__preload_model(
            self.__model_gfpgan_url, download_dir
        )
        self.__model_path_4x_ultrasharp = self.__preload_model(
            self.__model_4x_ultrasharp_url, download_dir
        )
        self.__loaded = True

    def upscale(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: Optional[int] = None,
    ) -> bytes:
        """Upscale ``image`` with the general-purpose SRVGGNetCompact weights.

        Args:
            image: A PIL image, or a string that is handed to ``download_image``.
            width: Target width; only used when ``resize_dimension`` is not given.
            height: Target height; only used when ``resize_dimension`` is not given.
            face_enhance: Run the result through GFPGAN instead of the plain
                upsampler.
            resize_dimension: Desired size of the longest side. If not provided,
                the larger of ``width`` and ``height`` is used.

        Returns:
            The upscaled image encoded as PNG bytes.
        """
        self.load()

        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=4,
            act_type="prelu",
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path,
            model,
        )

    def upscale_anime(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: Optional[int] = None,
    ) -> bytes:
        """Upscale ``image`` with the anime-tuned RRDBNet weights.

        Args:
            image: A PIL image, or a string that is handed to ``download_image``.
            width: Target width; only used when ``resize_dimension`` is not given.
            height: Target height; only used when ``resize_dimension`` is not given.
            face_enhance: Run the result through GFPGAN instead of the plain
                upsampler.
            resize_dimension: Desired size of the longest side. If not provided,
                the larger of ``width`` and ``height`` is used.

        Returns:
            The upscaled image encoded as PNG bytes.
        """
        self.load()

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path_anime,
            model,
        )

    def __preload_model(self, url: str, download_dir: Path) -> str:
        """Return the local path of the weights at ``url``, downloading if absent.

        The local file name is the last path segment of ``url``.
        """
        cached = str(download_dir / url.split("/")[-1])
        if os.path.exists(cached):
            return cached
        return load_file_from_url(
            url=url,
            model_dir=str(download_dir),
            progress=True,
            file_name=None,
        )

    @staticmethod
    def __pil_to_cv2(image):
        """Convert a PIL image to an OpenCV array via an in-memory PNG.

        Decoding with ``IMREAD_UNCHANGED`` keeps the channel layout (including
        alpha) exactly as a PNG file on disk would yield, without sharing a
        fixed temp file between concurrent calls.
        """
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        encoded = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
        return cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)

    def __internal_upscale(
        self,
        image,
        resize_dimension: Optional[int],
        face_enhance: bool,
        width: int,
        height: int,
        model_path: str,
        model,
    ) -> bytes:
        """Shared upscaling pipeline behind ``upscale`` and ``upscale_anime``.

        The factor is ``floor(resize_dimension / longest_input_side)``, never
        less than 2. A factor of exactly 4 uses 4x-UltraSharp; anything else
        uses RealESRGANer with the supplied ``model`` / ``model_path``.

        Returns:
            The upscaled image encoded as PNG bytes.
        """
        if isinstance(image, str):
            image = download_image(image, mode="RGBA")

        input_image = self.__pil_to_cv2(image)
        dimension = max(input_image.shape[0], input_image.shape[1])
        if not resize_dimension:
            resize_dimension = max(width, height)
        scale = max(math.floor(resize_dimension / dimension), 2)

        print("Upscaling by: ", scale)

        # NOTE(review): the models are built and run from ~/.cache, presumably
        # because a dependency writes auxiliary files relative to the working
        # directory -- confirm. try/finally guarantees the working directory
        # is restored even when construction or inference raises.
        os.chdir(str(Path.home() / ".cache"))
        try:
            if scale == 4:
                print("Using 4x-Ultrasharp")
                upsampler = Ultrasharp(
                    model_path=self.__model_path_4x_ultrasharp,
                    tile=320,
                    tile_pad=10,
                )
            else:
                print("Using RealESRGANer")
                upsampler = RealESRGANer(
                    scale=4,
                    model_path=model_path,
                    model=model,
                    half=False,
                    gpu_id="0",
                    tile=320,
                    tile_pad=10,
                    pre_pad=0,
                )

            if face_enhance:
                # Only build GFPGAN when it is actually needed; constructing it
                # loads another set of weights.
                face_enhancer = GFPGANer(
                    model_path=self.__model_path_gfpgan,
                    upscale=scale,
                    arch="clean",
                    channel_multiplier=2,
                    bg_upsampler=upsampler,
                )
                _, _, output = face_enhancer.enhance(
                    input_image,
                    has_aligned=False,
                    only_center_face=False,
                    paste_back=True,
                )
            else:
                output, _ = upsampler.enhance(input_image, outscale=scale)
        finally:
            os.chdir(get_root_dir())

        # Kept from the original: a copy of the latest result is written to
        # out.png in the project root directory.
        cv2.imwrite("out.png", output)
        return cv2.imencode(".png", output)[1].tobytes()

    @staticmethod
    def to_pil(buffer: bytes, mode="RGB") -> Image.Image:
        """Decode encoded image bytes into a PIL image converted to ``mode``."""
        return Image.open(io.BytesIO(buffer)).convert(mode)
|
|