File size: 3,806 Bytes
6c343a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import cv2
import glob
import os
import sys
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
import numpy as np
import torch
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from basicsr.utils import imwrite, img2tensor, tensor2img
from torchvision.transforms.functional import normalize
from basicsr.utils.registry import ARCH_REGISTRY

def load_sr(model_path, device, face):
    if not face=='codeformer':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) #alter to match dims as needed
        netscale = 4
        model_path = os.path.normpath(model_path)
        if not os.path.isfile(model_path):
            model_path = load_file_from_url(
                url='https://github.com/GucciFlipFlops1917/wav2lip-hq-updated-ESRGAN/releases/download/v0.0.1/4x_BigFace_v3_Clear.pth',
                model_dir='weights', progress=True, file_name=None)
        upsampler = RealESRGANer(
            scale=netscale,
            model_path=model_path,
            dni_weight=None,
            model=model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            half=True,
            gpu_id=0)
        if face==None:
            run_params=upsampler
        else:
            gfp = GFPGANer(
                model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth',
                upscale=2,
                arch='clean',
                channel_multiplier=2,
                bg_upsampler=upsampler)
            run_params=gfp
    else:
        run_params = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                                              connect_list=['32', '64', '128', '256']).to(device)
        ckpt_path = load_file_from_url(url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
                                       model_dir='weights/CodeFormer', progress=True, file_name=None)
        checkpoint = torch.load(ckpt_path)['params_ema']
        run_params.load_state_dict(checkpoint)
        run_params.eval()
    return run_params


def upscale(image, face, properties):
    try:
        if face==1:             ## GFP-GAN
            _, _, output = properties.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
        elif face==2:           ## CODEFORMER
            net = properties[0]
            device = properties[1]
            w = properties[2]
            image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)
            cropped_face_t = img2tensor(image / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
            try:
                with torch.no_grad():
                    cropped_face_t = net(cropped_face_t, w=w, adain=True)[0]
                    restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
                del cropped_face_t
                torch.cuda.empty_cache()
            except Exception as error:
                print(f'\tFailed inference for CodeFormer: {error}')
                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
            output = restored_face.astype('uint8')
        elif face==0:           ## ESRGAN
            img = image.astype(np.float32) / 255.
            output, _ = properties.enhance(image, outscale=4)
    except RuntimeError as error:
        print('Error', error)
        print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
    return output