from typing import List | |
import torch | |
import models.ultrasharp.arch as arch | |
from models.ultrasharp.util import infer_params, upscale_without_tiling | |
class Ultrasharp:
    """Upscale images with an RRDBNet checkpoint stored on disk.

    The checkpoint is read and the network rebuilt on every `enhance` call,
    so the instance itself holds nothing but the checkpoint location.
    """

    def __init__(self, filename):
        """Remember where the checkpoint lives; nothing is loaded here.

        Args:
            filename: Checkpoint location, handed unchanged to ``torch.load``.
        """
        self.filename = filename

    @staticmethod
    def _pick_device():
        """Return ``"cuda"`` when CUDA is usable, otherwise ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _build_model(self):
        """Load the checkpoint and return the matching RRDBNet in eval mode."""
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at trusted checkpoints; consider
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(self.filename, map_location="cpu")

        # The architecture hyper-parameters are derived from the weights.
        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
        model = arch.RRDBNet(
            in_nc=in_nc,
            out_nc=out_nc,
            nf=nf,
            nb=nb,
            upscale=mscale,
            plus=plus,
        )
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def enhance(self, img, outscale=4):
        """Upscale ``img`` with the checkpoint given at construction.

        Args:
            img: Input image, passed as-is to ``upscale_without_tiling``.
            outscale: Accepted for call compatibility but not used; the
                network's scale is the one inferred from the checkpoint.

        Returns:
            A ``(result, None)`` tuple, where ``result`` is whatever
            ``upscale_without_tiling`` returns. The second element is
            always ``None``.
        """
        model = self._build_model()
        # Previously hard-coded to "cuda", which raised on CPU-only hosts.
        model.to(self._pick_device())
        return upscale_without_tiling(model, img), None