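"""Super-resolution helpers for restoring and upscaling face frames.

load_sr() builds one of three restorers (Real-ESRGAN, Real-ESRGAN wrapped in
GFPGAN, or CodeFormer) and upscale() applies the chosen restorer to a single
BGR frame.
"""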
import os

import cv2
import numpy as np
import torch
from torchvision.transforms.functional import normalize

from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.registry import ARCH_REGISTRY
from gfpgan import GFPGANer
from realesrgan import RealESRGANer

def load_sr(model_path, device, face):
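    """Build and return the requested restorer.

    If face == 'codeformer', the CodeFormer network is loaded onto `device`.
    Otherwise a Real-ESRGAN upsampler is built from `model_path`; when `face`
    is not None it is additionally wrapped in a GFPGAN face restorer.
    """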
    if face != 'codeformer':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)  # RRDBNet x4 generator; alter dims to match the weights being loaded
netscale = 4
model_path = os.path.normpath(model_path)
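        # fall back to downloading the default 4x face-upscaling weights if no local file is found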
if not os.path.isfile(model_path):
model_path = load_file_from_url(
url='https://github.com/GucciFlipFlops1917/wav2lip-hq-updated-ESRGAN/releases/download/v0.0.1/4x_BigFace_v3_Clear.pth',
model_dir='weights', progress=True, file_name=None)
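        # Real-ESRGAN wrapper: untiled fp16 inference on GPU 0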
upsampler = RealESRGANer(
scale=netscale,
model_path=model_path,
dni_weight=None,
model=model,
tile=0,
tile_pad=10,
pre_pad=0,
half=True,
gpu_id=0)
        if face is None:
            run_params = upsampler
else:
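            # GFPGAN restores the face region; the Real-ESRGAN upsampler handles the background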
gfp = GFPGANer(
model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth',
upscale=2,
arch='clean',
channel_multiplier=2,
bg_upsampler=upsampler)
            run_params = gfp
else:
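        # CodeFormer: instantiate the architecture from the registry and load its EMA weights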
run_params = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
connect_list=['32', '64', '128', '256']).to(device)
ckpt_path = load_file_from_url(url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
model_dir='weights/CodeFormer', progress=True, file_name=None)
        checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
run_params.load_state_dict(checkpoint)
run_params.eval()
return run_params


def upscale(image, face, properties):
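    """Upscale one BGR frame with the restorer returned by load_sr().

    face selects the path: 1 = GFPGAN, 2 = CodeFormer (properties is the
    tuple (net, device, w)), 0 = Real-ESRGAN only.
    """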
    output = image  # fall back to the unmodified frame if enhancement fails
    try:
        if face == 1:  # GFPGAN
_, _, output = properties.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
        elif face == 2:  # CodeFormer
net = properties[0]
device = properties[1]
w = properties[2]
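            # CodeFormer expects a 512x512 face crop; w is its fidelity weight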
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)
cropped_face_t = img2tensor(image / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
cropped_face_t = net(cropped_face_t, w=w, adain=True)[0]
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
del cropped_face_t
torch.cuda.empty_cache()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}')
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
output = restored_face.astype('uint8')
        elif face == 0:  # Real-ESRGAN only
            # RealESRGANer.enhance takes the uint8 BGR frame directly
            output, _ = properties.enhance(image, outscale=4)
except RuntimeError as error:
print('Error', error)
        print('If you encounter CUDA out of memory, try a smaller tile size in RealESRGANer.')
return output
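

# Minimal usage sketch (illustrative only): the weight path and image
# filenames below are placeholder assumptions, not files shipped with this code.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Any value other than None or 'codeformer' selects the GFPGAN + Real-ESRGAN path
    restorer = load_sr('weights/4x_BigFace_v3_Clear.pth', device, 'gfpgan')
    frame = cv2.imread('input_frame.png')  # placeholder input frame
    if frame is not None:
        result = upscale(frame, 1, restorer)  # face=1 selects the GFPGAN branch
        cv2.imwrite('output_frame.png', result)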