jayparmr's picture
update : inference
35575bb verified
import io
import math
import os
from pathlib import Path
from typing import Optional, Union
import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from PIL import Image
from realesrgan import RealESRGANer # pyright: ignore
import internals.util.image as ImageUtil
from internals.util.commons import download_image
from internals.util.config import get_root_dir
from models.ultrasharp.model import Ultrasharp
class Upscaler:
    """Image super-resolution helper built on Real-ESRGAN, 4x-UltraSharp and GFPGAN.

    Model weights are fetched on the first call to :meth:`load` into
    ``~/.cache/realesrgan`` and reused from disk afterwards.
    """

    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
    __model_gfpgan_url = (
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
    )
    __model_4x_ultrasharp_url = (
        "https://comic-assets.s3.ap-south-1.amazonaws.com/models/4x-UltraSharp.pth"
    )

    # Set to True on the instance once every weight file has a local path.
    __loaded = False

    def load(self):
        """Make sure all model weight files exist locally.

        Missing files are downloaded into ``~/.cache/realesrgan``; the resolved
        local paths are stored on the instance. Subsequent calls return
        immediately.
        """
        if self.__loaded:
            return
        download_dir = Path.home() / ".cache" / "realesrgan"
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )
        self.__model_path_gfpgan = self.__preload_model(
            self.__model_gfpgan_url, download_dir
        )
        self.__model_path_4x_ultrasharp = self.__preload_model(
            self.__model_4x_ultrasharp_url, download_dir
        )
        self.__loaded = True

    def upscale(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: Optional[int] = None,
    ) -> bytes:
        """Upscale a general-purpose image and return it encoded as PNG bytes.

        Args:
            image: A PIL image, or a string handed to ``download_image``
                (presumably a URL) which is loaded in RGBA mode.
            width: Requested output width; only used when ``resize_dimension``
                is not provided.
            height: Requested output height; same role as ``width``.
            face_enhance: If True, run GFPGAN face restoration, using the
                selected upscaler for the background.
            resize_dimension: Desired size of the longer output side. If not
                provided (None or 0), the LARGER of ``width`` and ``height`` is
                used. The actual integer scale factor is
                ``floor(resize_dimension / longest_input_side)``, at least 2.

        Returns:
            The upscaled image as PNG-encoded bytes.
        """
        self.load()
        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=4,
            act_type="prelu",
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path,
            model,
        )

    def upscale_anime(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        face_enhance: bool,
        resize_dimension: Optional[int] = None,
    ) -> bytes:
        """Upscale an anime-style image and return it encoded as PNG bytes.

        Same parameters and return value as :meth:`upscale`; the default
        backbone is the RealESRGAN x4plus anime (6-block RRDBNet) model.
        Note that when the computed scale factor is exactly 4 the
        4x-UltraSharp model is used instead, regardless of this choice.
        """
        self.load()
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image,
            resize_dimension,
            face_enhance,
            width,
            height,
            self.__model_path_anime,
            model,
        )

    def __preload_model(self, url: str, download_dir: Path) -> str:
        """Return the local path of ``url``'s file, downloading it if absent.

        The file is expected at ``download_dir / <last URL path segment>``.
        """
        local_file = download_dir / url.split("/")[-1]
        if local_file.exists():
            return str(local_file)
        return load_file_from_url(
            url=url,
            model_dir=str(download_dir),
            progress=True,
            file_name=None,
        )

    @staticmethod
    def __to_cv2_image(image: Image.Image) -> np.ndarray:
        """Convert a PIL image into an OpenCV array via an in-memory PNG round trip.

        Encoding to PNG and decoding with ``IMREAD_UNCHANGED`` gives the same
        array ``cv2.imread`` would produce for a saved PNG (BGR/BGRA order,
        alpha kept), without touching a shared file on disk.
        """
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        encoded = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
        decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
        if decoded is None:
            raise ValueError("Could not decode input image for upscaling")
        return decoded

    def __create_upsampler(self, scale: int, model_path: str, model):
        """Pick the background upsampler: 4x-UltraSharp for scale 4, else Real-ESRGAN."""
        if scale == 4:
            print("Using 4x-Ultrasharp")
            return Ultrasharp(
                model_path=self.__model_path_4x_ultrasharp,
                tile=320,
                tile_pad=10,
            )
        print("Using RealESRGANer")
        return RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            half=False,
            gpu_id="0",
            tile=320,
            tile_pad=10,
            pre_pad=0,
        )

    def __internal_upscale(
        self,
        image,
        resize_dimension: Optional[int],
        face_enhance: bool,
        width: int,
        height: int,
        model_path: str,
        model,
    ) -> bytes:
        """Shared implementation behind :meth:`upscale` and :meth:`upscale_anime`.

        Returns:
            The upscaled image as PNG-encoded bytes.

        Raises:
            ValueError: If the input image cannot be decoded.
            RuntimeError: If the result cannot be PNG-encoded.
        """
        if isinstance(image, str):
            image = download_image(image, mode="RGBA")
        input_image = self.__to_cv2_image(image)

        dimension = max(input_image.shape[0], input_image.shape[1])
        if not resize_dimension:
            resize_dimension = max(width, height)
        # Integer scale factor, never below 2x.
        scale = max(math.floor(resize_dimension / dimension), 2)
        print("Upscaling by: ", scale)

        # NOTE(review): os.chdir is process-wide; presumably needed so helper
        # libraries that resolve weight paths relative to the cwd keep them
        # under ~/.cache. This is not safe for concurrent calls in one process.
        os.chdir(str(Path.home() / ".cache"))
        try:
            upsampler = self.__create_upsampler(scale, model_path, model)
            if face_enhance:
                # Built only when needed: constructing GFPGANer loads its weights.
                face_enhancer = GFPGANer(
                    model_path=self.__model_path_gfpgan,
                    upscale=scale,
                    arch="clean",
                    channel_multiplier=2,
                    bg_upsampler=upsampler,
                )
                _, _, output = face_enhancer.enhance(
                    input_image,
                    has_aligned=False,
                    only_center_face=False,
                    paste_back=True,
                )
            else:
                output, _ = upsampler.enhance(input_image, outscale=scale)
        finally:
            # Always restore the working directory, even if inference fails.
            os.chdir(get_root_dir())

        # Debug snapshot of the result, written relative to the restored cwd.
        cv2.imwrite("out.png", output)
        ok, encoded = cv2.imencode(".png", output)
        if not ok:
            raise RuntimeError("PNG encoding of the upscaled image failed")
        return encoded.tobytes()

    @staticmethod
    def to_pil(buffer: bytes, mode="RGB") -> Image.Image:
        """Decode encoded image bytes (e.g. from :meth:`upscale`) into a PIL image in ``mode``."""
        return Image.open(io.BytesIO(buffer)).convert(mode)