File size: 1,516 Bytes
1bc457e
 
35575bb
 
1bc457e
 
 
22df957
1bc457e
 
 
22df957
 
 
 
1bc457e
 
 
 
 
 
 
 
 
 
 
 
 
 
35575bb
 
 
 
 
 
 
22df957
35575bb
 
 
 
 
 
 
 
 
 
1bc457e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import List

import cv2
import numpy as np
import torch

import models.ultrasharp.arch as arch
from models.ultrasharp.util import infer_params, upscale


class Ultrasharp:
    """Upscale images with an RRDBNet checkpoint.

    The checkpoint is read from disk on every :meth:`enhance` call and the
    network hyper-parameters are inferred from its state dict via
    ``infer_params``; nothing is cached between calls.
    """

    def __init__(self, model_path, tile_pad=0, tile=0):
        """Remember the checkpoint location and the tiling options.

        Args:
            model_path: Location of the checkpoint passed to ``torch.load``.
            tile_pad: Forwarded unchanged to ``upscale``.
            tile: Forwarded unchanged to ``upscale``.
        """
        self.filename = model_path
        self.tile_pad = tile_pad
        self.tile = tile

    @staticmethod
    def _pick_device():
        """Return ``"cuda"`` when CUDA is usable, otherwise ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_model(self):
        """Build the network from the checkpoint, in eval mode, on the device."""
        # NOTE(security): torch.load can unpickle arbitrary objects unless
        # weights-only loading is in effect; only use trusted checkpoints.
        state_dict = torch.load(self.filename, map_location="cpu")

        in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
        model = arch.RRDBNet(
            in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus
        )
        model.load_state_dict(state_dict)
        model.eval()

        # Fall back to CPU instead of failing outright when CUDA is missing.
        # NOTE(review): this only helps if ``upscale`` puts its tensors on the
        # model's device -- confirm against models.ultrasharp.util.
        model.to(self._pick_device())
        return model

    def enhance(self, img, outscale=4):
        """Upscale ``img`` and, when present, its alpha channel.

        Args:
            img: Image array. A 2-D array is treated as single-channel, an
                array whose third axis has length 4 is treated as colour plus
                alpha, and anything else is passed to ``upscale`` as is.
            outscale: Accepted for interface compatibility but never read;
                the network is built with the scale inferred from the
                checkpoint.

        Returns:
            A ``(image, None)`` tuple. The image has four channels when the
            input had an alpha channel and is 2-D when the input was 2-D; the
            second element is always ``None``.
        """
        model = self._load_model()

        alpha = None
        if len(img.shape) == 2:
            # Single-channel input has no third axis to inspect; replicate it
            # to three channels so it takes the same path as a colour image.
            img_mode = "L"
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # colour image with an alpha channel
            img_mode = "RGBA"
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            # The alpha plane is upscaled as a three-channel image too.
            alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = "RGB"

        img = upscale(model, img, self.tile_pad, self.tile)

        if img_mode == "L":
            # Collapse the replicated channels back to a single plane.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        elif img_mode == "RGBA":
            # Upscale alpha with the same model, then reduce it to one plane.
            output_alpha = upscale(model, alpha, self.tile_pad, self.tile)
            output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)

            # Re-attach the alpha plane as the fourth channel.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
            img[:, :, 3] = output_alpha
        return img, None