|
from typing import List |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
|
|
import models.ultrasharp.arch as arch |
|
from models.ultrasharp.util import infer_params, upscale |
|
|
|
|
|
class Ultrasharp:
    """Upscale images with an RRDBNet checkpoint (e.g. a 4x-UltraSharp model).

    The network hyper-parameters are inferred from the checkpoint's state dict
    via ``infer_params``. The model is built lazily on the first ``enhance``
    call and then reused by later calls on the same instance. It stays on the
    GPU until the instance is discarded.
    """

    def __init__(self, model_path, tile_pad=0, tile=0):
        """Store configuration; no file or GPU work happens here.

        Args:
            model_path: Path to a checkpoint file readable by ``torch.load``.
            tile_pad: Forwarded unchanged to ``upscale`` — presumably the tile
                overlap in pixels (TODO confirm in models.ultrasharp.util).
            tile: Forwarded unchanged to ``upscale`` — presumably the tile
                size, with 0 meaning "no tiling" (TODO confirm).
        """
        self.filename = model_path
        self.tile_pad = tile_pad
        self.tile = tile
        # Network cache, populated by _get_model() on first use.
        self._model = None

    def _get_model(self):
        """Return the cached network, loading and building it on first use."""
        # getattr tolerates instances created without __init__ (e.g. via __new__).
        model = getattr(self, "_model", None)
        if model is not None:
            return model

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code — only load trusted checkpoints. Consider
        # weights_only=True on torch versions that support it.
        state_dict = torch.load(self.filename, map_location="cpu")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
        model = arch.RRDBNet(
            in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus
        )
        model.load_state_dict(state_dict)
        model.eval()
        # The device is hard-coded: a CUDA GPU is required.
        model.to("cuda")

        # Cache only after every step succeeded, so a failed load is retried.
        self._model = model
        return model

    def enhance(self, img, outscale=4):
        """Upscale ``img`` with the network.

        Args:
            img: Image array of shape (H, W), (H, W, 3) or (H, W, 4).
                Channel order is presumably BGR(A) as produced by OpenCV
                (TODO confirm against callers).
            outscale: Currently ignored. The output scale is whatever the
                checkpoint's inferred scale produces. The parameter is kept for
                API compatibility.

        Returns:
            Tuple ``(output, None)``:
            - 4-channel input gives a 4-channel output whose alpha plane is
              upscaled separately.
            - 2-D grayscale input gives a 2-D output.
            - Otherwise the output is whatever ``upscale`` returns.
        """
        model = self._get_model()

        alpha = None
        is_gray = img.ndim == 2
        if is_gray:
            # Previously img.shape[2] raised IndexError here. Expand to three
            # channels, matching the 3-channel inputs the colour and alpha
            # paths feed to upscale(), and collapse back afterwards.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            # Split off alpha and replicate it to 3 channels so it can be run
            # through the same network as the colour planes.
            alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2RGB)
            img = img[:, :, 0:3]

        output = upscale(model, img, self.tile_pad, self.tile)

        if is_gray:
            output = cv2.cvtColor(output, cv2.COLOR_BGR2GRAY)
        elif alpha is not None:
            output_alpha = upscale(model, alpha, self.tile_pad, self.tile)
            output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            output = cv2.cvtColor(output, cv2.COLOR_BGR2BGRA)
            output[:, :, 3] = output_alpha

        return output, None
|
|