Spaces:
Running
on
A10G
Running
on
A10G
Xintao
committed on
Commit
·
d0fef57
1
Parent(s):
d561a66
add GFPGAN
Browse files- app.py +78 -4
- gfpgan_utils.py +119 -0
- gfpganv1_clean_arch.py +325 -0
- packages.txt +3 -0
- realesrgan_utils.py +280 -0
- requirements.txt +11 -0
- srvgg_arch.py +67 -0
- stylegan2_clean_arch.py +369 -0
- weights/PutWeightsHere +0 -0
app.py
CHANGED
@@ -1,8 +1,82 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
|
|
3 |
|
4 |
-
def greet(name):
|
5 |
-
return "Hello " + name + "!!"
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
import gradio as gr
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from realesrgan_utils import RealESRGANer
|
9 |
+
from srvgg_arch import SRVGGNetCompact
|
10 |
+
|
11 |
+
# Run `pip freeze`; its output lists the packages installed in the environment.
os.system("pip freeze")

# Download the pretrained checkpoints into ./weights with wget.
for weight_url in (
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
        "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
):
    os.system(f"wget {weight_url} -P ./weights")

# Download the example images referenced by the Gradio interface.
for image_url, file_name in (
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
         'lincoln.jpg'),
        ('https://upload.wikimedia.org/wikipedia/commons/5/50/Albert_Einstein_%28Nobel%29.png',
         'einstein.png'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/9/9d/Thomas_Edison2.jpg/1024px-Thomas_Edison2.jpg',
         'edison.jpg'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a9/Henry_Ford_1888.jpg/1024px-Henry_Ford_1888.jpg',
         'Henry.jpg'),
        ('https://upload.wikimedia.org/wikipedia/commons/thumb/0/06/Frida_Kahlo%2C_by_Guillermo_Kahlo.jpg/800px-Frida_Kahlo%2C_by_Guillermo_Kahlo.jpg',
         'Frida.jpg'),
):
    torch.hub.download_url_to_file(image_url, file_name)

# Background upsampler: the compact SRVGG network loaded from the general x4 v3 checkpoint.
netscale = 4
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = os.path.join('weights', 'realesr-general-x4v3.pth')
upsampler = RealESRGANer(
    scale=netscale,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=True)

# Face enhancer: GFPGAN restores the faces, `upsampler` handles the background.
from gfpgan_utils import GFPGANer

face_enhancer = GFPGANer(
    model_path='weights/GFPGANv1.3.pth',
    upscale=2,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=upsampler)

os.makedirs('output', exist_ok=True)
|
46 |
+
|
47 |
+
|
48 |
+
def inference(img):
    """Restore the faces in an image file and return the result as a PIL image.

    Args:
        img (str): Path of the input image (the Gradio input component is
            configured with ``type="filepath"``).

    Returns:
        PIL.Image.Image: The restored image in RGB (or RGBA) channel order.

    Raises:
        ValueError: If OpenCV cannot decode an image from ``img``.
        RuntimeError: Re-raised from the face enhancer (for example CUDA out
            of memory) after a hint has been printed.
    """
    image = cv2.imread(img, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f'Could not read an image from {img!r}.')

    # Inputs shorter than 400 px are enlarged 2x before restoration.
    h, w = image.shape[0:2]
    if h < 400:
        image = cv2.resize(image, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

    try:
        _, _, output = face_enhancer.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
    except RuntimeError as error:
        print('Error', error)
        print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        # Without a result there is nothing to return; surface the real error
        # instead of failing later on an unbound `output`.
        raise

    # The pipeline works in OpenCV's BGR(A) channel order, while PIL expects
    # RGB(A); convert so the colours are not swapped in the returned image.
    if output.ndim == 3 and output.shape[2] == 4:
        output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
    elif output.ndim == 3 and output.shape[2] == 3:
        output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)

    return Image.fromarray(output)
|
71 |
|
|
|
|
|
72 |
|
73 |
+
# Text shown on the demo page.
title = "GFP-GAN"
description = (
    "Gradio demo for GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below. Please click submit only once"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2101.04061' target='_blank'>"
    "Towards Real-World Blind Face Restoration with Generative Facial Prior</a>"
    " | <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Github Repo</a></p>"
    "<center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_GFPGAN' alt='visitor badge'></center>"
)

# Build the interface around `inference` (one example row per image) and start it.
gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Image(type="filepath", label="Input")],
    outputs=gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[[name] for name in ('lincoln.jpg', 'einstein.png', 'edison.jpg', 'Henry.jpg', 'Frida.jpg')],
).launch()
|
gfpgan_utils.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from basicsr.utils import img2tensor, tensor2img
|
6 |
+
from basicsr.utils.download_util import load_file_from_url
|
7 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
8 |
+
from torchvision.transforms.functional import normalize
|
9 |
+
|
10 |
+
from gfpganv1_clean_arch import GFPGANv1Clean
|
11 |
+
|
12 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
13 |
+
|
14 |
+
|
15 |
+
class GFPGANer():
    """Helper for restoration with GFPGAN.

    It detects and crops faces, then resizes them to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces are pasted back onto the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can be an https URL (it is downloaded first).
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture name. This helper always builds the clean
            architecture (GFPGANv1Clean); the argument is accepted but not consulted. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
        device (torch.device | None): Device to run on. When None, CUDA is used if available,
            otherwise the CPU. Default: None.
    """

    def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        # initialize the GFP-GAN
        self.gfpgan = GFPGANv1Clean(
            out_size=512,
            num_style_feat=512,
            channel_multiplier=channel_multiplier,
            decoder_load_path=None,
            fix_decoder=False,
            num_mlp=8,
            input_is_latent=True,
            different_w=True,
            narrow=1,
            sft_half=True)

        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=self.device)

        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
        # Deserialize onto the CPU first so a checkpoint saved from a GPU also
        # loads on a CPU-only host; the model is moved to self.device below.
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer the EMA weights when the checkpoint provides them
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
        """Restore the faces found in an image.

        Args:
            img (ndarray): Input image as read by OpenCV.
            has_aligned (bool): Whether the input is already an aligned face. If True, the image is
                resized to 512x512 and used directly as the single cropped face. Default: False.
            only_center_face (bool): Passed to the face helper's landmark detection. Default: False.
            paste_back (bool): Whether to paste the restored faces back onto the (optionally
                upsampled) image. Default: True.

        Returns:
            tuple: (cropped_faces, restored_faces, restored_img). restored_img is None when
            has_aligned is True or paste_back is False.
        """
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: scale to [0, 1], convert BGR -> RGB, then normalize to [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # best effort: keep the unrestored crop when inference fails
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
        else:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
gfpganv1_clean_arch.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from stylegan2_clean_arch import StyleGAN2GeneratorClean
|
10 |
+
|
11 |
+
|
12 |
+
class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
    """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
        super(StyleGAN2GeneratorCSFT, self).__init__(
            out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow)
        self.sft_half = sft_half

    def forward(self,
                styles,
                conditions,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorCSFT.

        Args:
            styles (list[Tensor]): Sample codes of styles. Must contain one or two entries.
            conditions (list[Tensor]): SFT conditions to generators, stored as alternating
                scale and shift tensors.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. It is needed
                when truncation < 1. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: (image, latent) when return_latents is True, otherwise (image, None).

        Raises:
            ValueError: If styles does not contain exactly one or two entries.
        """
        # Fail early with a clear message; otherwise `latent` would be unbound below.
        if len(styles) not in (1, 2):
            raise ValueError(f'styles must contain 1 or 2 entries, but got {len(styles)}.')

        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            styles = [truncation_latent + truncation * (style - truncation_latent) for style in styles]
        # get style latents with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        else:  # two styles: mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)

            # the conditions may have fewer levels
            if i < len(conditions):
                # SFT part to combine the conditions: conditions[i - 1] scales, conditions[i] shifts
                if self.sft_half:  # only apply SFT to half of the channels
                    out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
                    out_sft = out_sft * conditions[i - 1] + conditions[i]
                    out = torch.cat([out_same, out_sft], dim=1)
                else:  # apply SFT to all the channels
                    out = out * conditions[i - 1] + conditions[i]

            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
119 |
+
|
120 |
+
|
121 |
+
class ResBlock(nn.Module):
    """Residual block with bilinear upsampling/downsampling.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.

    Raises:
        ValueError: If mode is neither 'down' nor 'up'.
    """

    def __init__(self, in_channels, out_channels, mode='down'):
        super(ResBlock, self).__init__()

        # Reject unknown modes here; otherwise scale_factor would stay unset
        # and the failure would only surface later inside forward().
        if mode == 'down':
            scale_factor = 0.5
        elif mode == 'up':
            scale_factor = 2
        else:
            raise ValueError(f"mode must be 'down' or 'up', but got {mode!r}.")

        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.scale_factor = scale_factor

    def forward(self, x):
        """Apply conv -> resize -> conv on the main path and add the resized 1x1-conv skip path."""
        out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
        # upsample/downsample
        out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
        # skip: resize the input the same way, then project it to out_channels
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
        skip = self.skip(x)
        out = out + skip
        return out
|
151 |
+
|
152 |
+
|
153 |
+
@ARCH_REGISTRY.register()
class GFPGANv1Clean(nn.Module):
    """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.

    It is the clean version without custom compiled CUDA extensions used in StyleGAN2.

    Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 1.
        decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
        fix_decoder (bool): Whether to fix the decoder. Default: True.

        num_mlp (int): Layer number of MLP style layers. Default: 8.
        input_is_latent (bool): Whether input is latent style. Default: False.
        different_w (bool): Whether to use different latent w for different layers. Default: False.
        narrow (float): The narrow ratio for channels. Default: 1.
        sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
    """

    def __init__(
            self,
            out_size,
            num_style_feat=512,
            channel_multiplier=1,
            decoder_load_path=None,
            fix_decoder=True,
            # for stylegan decoder
            num_mlp=8,
            input_is_latent=False,
            different_w=False,
            narrow=1,
            sft_half=False):

        super(GFPGANv1Clean, self).__init__()
        self.input_is_latent = input_is_latent
        self.different_w = different_w
        self.num_style_feat = num_style_feat

        unet_narrow = narrow * 0.5  # by default, use a half of input channels
        # channel count of the U-Net features, keyed by spatial resolution
        channels = {
            '4': int(512 * unet_narrow),
            '8': int(512 * unet_narrow),
            '16': int(512 * unet_narrow),
            '32': int(512 * unet_narrow),
            '64': int(256 * channel_multiplier * unet_narrow),
            '128': int(128 * channel_multiplier * unet_narrow),
            '256': int(64 * channel_multiplier * unet_narrow),
            '512': int(32 * channel_multiplier * unet_narrow),
            '1024': int(16 * channel_multiplier * unet_narrow)
        }

        # math.log2 is exact for powers of two, unlike the float quotient behind math.log(x, 2)
        self.log_size = int(math.log2(out_size))
        first_out_size = 2**self.log_size

        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)

        # downsample
        in_channels = channels[f'{first_out_size}']
        self.conv_body_down = nn.ModuleList()
        for i in range(self.log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
            in_channels = out_channels

        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)

        # upsample
        in_channels = channels['4']
        self.conv_body_up = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
            in_channels = out_channels

        # to RGB
        self.toRGB = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))

        # with different_w, the linear layer emits one style vector per decoder latent
        if different_w:
            linear_out_channel = (self.log_size * 2 - 2) * num_style_feat
        else:
            linear_out_channel = num_style_feat

        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)

        # the decoder: stylegan2 generator with SFT modulations
        self.stylegan_decoder = StyleGAN2GeneratorCSFT(
            out_size=out_size,
            num_style_feat=num_style_feat,
            num_mlp=num_mlp,
            channel_multiplier=channel_multiplier,
            narrow=narrow,
            sft_half=sft_half)

        # load pre-trained stylegan2 model if necessary
        if decoder_load_path:
            self.stylegan_decoder.load_state_dict(
                torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
        # fix decoder without updating params
        if fix_decoder:
            for _, param in self.stylegan_decoder.named_parameters():
                param.requires_grad = False

        # for SFT modulations (scale and shift)
        self.condition_scale = nn.ModuleList()
        self.condition_shift = nn.ModuleList()
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            if sft_half:
                sft_out_channels = out_channels
            else:
                sft_out_channels = out_channels * 2
            self.condition_scale.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
            self.condition_shift.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))

    def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
        """Forward function for GFPGANv1Clean.

        Args:
            x (Tensor): Input images.
            return_latents (bool): Forwarded to the StyleGAN2 decoder; the latent it returns is
                discarded here. Default: False.
            return_rgb (bool): Whether return intermediate rgb images. Default: True.
            randomize_noise (bool): Randomize noise, used when 'noise' is None. Default: True.

        Returns:
            tuple: (image, out_rgbs). out_rgbs is the list of intermediate rgb images, empty when
            return_rgb is False.
        """
        conditions = []
        unet_skips = []
        out_rgbs = []

        # encoder
        feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
        for i in range(self.log_size - 2):
            feat = self.conv_body_down[i](feat)
            # prepend, so the skips end up ordered from the deepest level outwards
            unet_skips.insert(0, feat)
        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)

        # style code
        style_code = self.final_linear(feat.view(feat.size(0), -1))
        if self.different_w:
            style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)

        # decode
        for i in range(self.log_size - 2):
            # add unet skip
            feat = feat + unet_skips[i]
            # ResUpLayer
            feat = self.conv_body_up[i](feat)
            # generate scale and shift for SFT layers
            scale = self.condition_scale[i](feat)
            conditions.append(scale.clone())
            shift = self.condition_shift[i](feat)
            conditions.append(shift.clone())
            # generate rgb images
            if return_rgb:
                out_rgbs.append(self.toRGB[i](feat))

        # decoder
        image, _ = self.stylegan_decoder([style_code],
                                         conditions,
                                         return_latents=return_latents,
                                         input_is_latent=self.input_is_latent,
                                         randomize_noise=randomize_noise)

        return image, out_rgbs
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
ffmpeg
|
2 |
+
libsm6
|
3 |
+
libxext6
|
realesrgan_utils.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import queue
|
6 |
+
import threading
|
7 |
+
import torch
|
8 |
+
from basicsr.utils.download_util import load_file_from_url
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
12 |
+
|
13 |
+
|
14 |
+
class RealESRGANer():
|
15 |
+
"""A helper class for upsampling images with RealESRGAN.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
19 |
+
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
20 |
+
model (nn.Module): The defined network. Default: None.
|
21 |
+
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
22 |
+
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
23 |
+
0 denotes for do not use tile. Default: 0.
|
24 |
+
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
25 |
+
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
26 |
+
half (float): Whether to use half precision during inference. Default: False.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
|
30 |
+
self.scale = scale
|
31 |
+
self.tile_size = tile
|
32 |
+
self.tile_pad = tile_pad
|
33 |
+
self.pre_pad = pre_pad
|
34 |
+
self.mod_scale = None
|
35 |
+
self.half = half
|
36 |
+
|
37 |
+
# initialize model
|
38 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
39 |
+
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
|
40 |
+
if model_path.startswith('https://'):
|
41 |
+
model_path = load_file_from_url(
|
42 |
+
url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
|
43 |
+
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
44 |
+
# prefer to use params_ema
|
45 |
+
if 'params_ema' in loadnet:
|
46 |
+
keyname = 'params_ema'
|
47 |
+
else:
|
48 |
+
keyname = 'params'
|
49 |
+
model.load_state_dict(loadnet[keyname], strict=True)
|
50 |
+
model.eval()
|
51 |
+
self.model = model.to(self.device)
|
52 |
+
if self.half:
|
53 |
+
self.model = self.model.half()
|
54 |
+
|
55 |
+
def pre_process(self, img):
|
56 |
+
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
|
57 |
+
"""
|
58 |
+
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
59 |
+
self.img = img.unsqueeze(0).to(self.device)
|
60 |
+
if self.half:
|
61 |
+
self.img = self.img.half()
|
62 |
+
|
63 |
+
# pre_pad
|
64 |
+
if self.pre_pad != 0:
|
65 |
+
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
66 |
+
# mod pad for divisible borders
|
67 |
+
if self.scale == 2:
|
68 |
+
self.mod_scale = 2
|
69 |
+
elif self.scale == 1:
|
70 |
+
self.mod_scale = 4
|
71 |
+
if self.mod_scale is not None:
|
72 |
+
self.mod_pad_h, self.mod_pad_w = 0, 0
|
73 |
+
_, _, h, w = self.img.size()
|
74 |
+
if (h % self.mod_scale != 0):
|
75 |
+
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
76 |
+
if (w % self.mod_scale != 0):
|
77 |
+
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
78 |
+
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
79 |
+
|
80 |
+
def process(self):
|
81 |
+
# model inference
|
82 |
+
self.output = self.model(self.img)
|
83 |
+
|
84 |
+
def tile_process(self):
    """Upscale ``self.img`` tile by tile and assemble the result in ``self.output``.

    The input is cut into tiles of at most ``self.tile_size`` x ``self.tile_size`` pixels.
    Every tile is widened by up to ``self.tile_pad`` pixels of surrounding context (clipped
    at the image border) before it is fed to the model; only the part of the upscaled tile
    that corresponds to the un-padded tile is copied into the output canvas.

    Modified from: https://github.com/ata4/esrgan-launcher

    Raises:
        RuntimeError: Re-raised from the model call after the failing tile has been
            reported. Without the re-raise the code below would use ``output_tile``
            although it was never assigned (first tile) or still holds the previous
            tile's result.
    """
    batch, channel, height, width = self.img.shape
    output_height = height * self.scale
    output_width = width * self.scale
    output_shape = (batch, channel, output_height, output_width)

    # start with black image
    self.output = self.img.new_zeros(output_shape)
    tiles_x = math.ceil(width / self.tile_size)
    tiles_y = math.ceil(height / self.tile_size)

    # loop over all tiles
    for y in range(tiles_y):
        for x in range(tiles_x):
            # extract tile from input image
            ofs_x = x * self.tile_size
            ofs_y = y * self.tile_size
            # input tile area on total image
            input_start_x = ofs_x
            input_end_x = min(ofs_x + self.tile_size, width)
            input_start_y = ofs_y
            input_end_y = min(ofs_y + self.tile_size, height)

            # input tile area on total image with padding
            input_start_x_pad = max(input_start_x - self.tile_pad, 0)
            input_end_x_pad = min(input_end_x + self.tile_pad, width)
            input_start_y_pad = max(input_start_y - self.tile_pad, 0)
            input_end_y_pad = min(input_end_y + self.tile_pad, height)

            # input tile dimensions
            input_tile_width = input_end_x - input_start_x
            input_tile_height = input_end_y - input_start_y
            tile_idx = y * tiles_x + x + 1
            input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

            # upscale tile
            try:
                with torch.no_grad():
                    output_tile = self.model(input_tile)
            except RuntimeError as error:
                print('Error', error)
                print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
                # There is no valid result for this tile, so stop instead of pasting
                # undefined or stale data into the output.
                raise
            print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

            # output tile area on total image
            output_start_x = input_start_x * self.scale
            output_end_x = input_end_x * self.scale
            output_start_y = input_start_y * self.scale
            output_end_y = input_end_y * self.scale

            # output tile area without padding
            output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
            output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
            output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
            output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

            # put tile into output image
            self.output[:, :, output_start_y:output_end_y,
                        output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                   output_start_x_tile:output_end_x_tile]
|
148 |
+
|
149 |
+
def post_process(self):
    """Crop the padding added in pre-processing off ``self.output`` and return it.

    The divisibility padding (``mod_pad_h`` / ``mod_pad_w``) is removed first, then the
    fixed ``pre_pad``; both are multiplied by ``self.scale`` because the output has been
    upscaled by that factor.

    Returns:
        Tensor: The cropped tensor, which is also stored back on ``self.output``.
    """
    result = self.output

    # undo the divisibility padding
    if self.mod_scale is not None:
        _, _, rows, cols = result.size()
        keep_rows = rows - self.mod_pad_h * self.scale
        keep_cols = cols - self.mod_pad_w * self.scale
        result = result[:, :, 0:keep_rows, 0:keep_cols]

    # undo the fixed pre-padding
    if self.pre_pad != 0:
        _, _, rows, cols = result.size()
        trim = self.pre_pad * self.scale
        result = result[:, :, 0:rows - trim, 0:cols - trim]

    self.output = result
    return self.output
|
159 |
+
|
160 |
+
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
    """Upscale one image and return it together with its detected colour mode.

    Args:
        img (ndarray): Input image, either 2-D (treated as grey) or 3-D with the channel
            axis last; four channels are treated as colour plus alpha. Colour input goes
            through ``cv2.COLOR_BGR2RGB`` before inference, so it is presumably in OpenCV's
            BGR(A) order - confirm against callers.
        outscale (float | None): Final scale relative to the input size. When given and
            different from ``self.scale`` the result is resized with Lanczos interpolation.
        alpha_upsampler (str): ``'realesrgan'`` runs the alpha channel through the network;
            any other value resizes it with ``cv2.resize``.

    Returns:
        tuple: ``(output, img_mode)`` where ``output`` is a ``uint8`` array (``uint16`` when
        the input was detected as 16-bit) and ``img_mode`` is ``'L'``, ``'RGB'`` or ``'RGBA'``.
    """

    def run_network(rgb):
        # One full pass: pad, infer (tiled or whole), un-pad, then convert the clamped
        # result back to a channel-last array with the first and third channels swapped.
        self.pre_process(rgb)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        restored = self.post_process()
        restored = restored.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        return np.transpose(restored[[2, 1, 0], :, :], (1, 2, 0))

    h_input, w_input = img.shape[0:2]
    # img: numpy
    img = img.astype(np.float32)
    # any value above 256 is taken as evidence of a 16-bit image
    is_16bit = np.max(img) > 256
    if is_16bit:
        max_range = 65535
        print('\tInput is a 16-bit image')
    else:
        max_range = 255
    img = img / max_range

    alpha = None
    if len(img.shape) == 2:  # gray image
        img_mode = 'L'
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        has_alpha = img.shape[2] == 4
        img_mode = 'RGBA' if has_alpha else 'RGB'
        if has_alpha:
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if has_alpha and alpha_upsampler == 'realesrgan':
            # the network expects three channels, so replicate alpha
            alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)

    # ------------------- process image (without the alpha channel) ------------------- #
    output_img = run_network(img)
    if img_mode == 'L':
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

    # ------------------- process the alpha channel if necessary ------------------- #
    if img_mode == 'RGBA':
        if alpha_upsampler == 'realesrgan':
            output_alpha = cv2.cvtColor(run_network(alpha), cv2.COLOR_BGR2GRAY)
        else:  # use the cv2 resize for alpha channel
            h, w = alpha.shape[0:2]
            output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

        # merge the alpha channel
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
        output_img[:, :, 3] = output_alpha

    # ------------------------------ return ------------------------------ #
    if is_16bit:
        output = (output_img * 65535.0).round().astype(np.uint16)
    else:
        output = (output_img * 255.0).round().astype(np.uint8)

    if outscale is not None and outscale != float(self.scale):
        target_size = (int(w_input * outscale), int(h_input * outscale))
        output = cv2.resize(output, target_size, interpolation=cv2.INTER_LANCZOS4)

    return output, img_mode
|
231 |
+
|
232 |
+
|
233 |
+
class PrefetchReader(threading.Thread):
    """Thread that reads images ahead of time and hands them out through iteration.

    ``run`` loads every path with ``cv2.imread`` and pushes the result into a bounded
    queue, followed by a ``None`` end marker; iterating the reader pops from that queue
    until the marker is seen.

    Args:
        img_list (list[str]): Paths of the images to read, in order.
        num_prefetch_queue (int): Maximum number of items held in the queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.img_list = img_list
        self.que = queue.Queue(num_prefetch_queue)

    def run(self):
        # NOTE(review): if cv2.imread yields None for an unreadable file, that None is
        # indistinguishable from the end marker below and iteration stops early - confirm.
        for path in self.img_list:
            self.que.put(cv2.imread(path, cv2.IMREAD_UNCHANGED))

        # end-of-stream marker
        self.que.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.que.get()
        if item is not None:
            return item
        raise StopIteration
|
261 |
+
|
262 |
+
|
263 |
+
class IOConsumer(threading.Thread):
    """Thread that writes queued images to disk with ``cv2.imwrite``.

    Each queue item is a mapping with the keys ``'output'`` (the image) and
    ``'save_path'`` (where to write it). The string ``'quit'`` stops the worker.

    Args:
        opt: Options object; stored on the instance but not used by this class.
        que: Queue the worker reads its messages from.
        qid: Identifier used in the shutdown message.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self.opt = opt
        self.qid = qid
        self._queue = que

    def run(self):
        while True:
            task = self._queue.get()
            stop_requested = isinstance(task, str) and task == 'quit'
            if stop_requested:
                break

            image = task['output']
            destination = task['save_path']
            cv2.imwrite(destination, image)
        print(f'IO worker {self.qid} is done.')
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.7
|
2 |
+
numpy
|
3 |
+
opencv-python
|
4 |
+
torchvision
|
5 |
+
scipy
|
6 |
+
tqdm
|
7 |
+
basicsr>=1.4.1
|
8 |
+
facexlib>=0.2.4
|
9 |
+
lmdb
|
10 |
+
pyyaml
|
11 |
+
yapf
|
srvgg_arch.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn as nn
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
|
5 |
+
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    All convolutions run at the input resolution; upsampling happens only in the last
    step through a pixel shuffle, and the nearest-neighbour upsampled input is added so
    the network learns a residual.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.

    Raises:
        ValueError: If ``act_type`` is not one of the supported options.
    """

    _ACT_TYPES = ('relu', 'prelu', 'leakyrelu')

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        # Fail early with a clear message; an unknown type would otherwise leave the
        # activation undefined while the layers are being assembled.
        if act_type not in self._ACT_TYPES:
            raise ValueError(f'Unsupported act_type {act_type!r}; expected one of {self._ACT_TYPES}.')
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        # Layer order (and therefore the state-dict indices) is:
        # conv, act, [conv, act] * num_conv, conv
        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        self.body.append(self._make_activation())

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            self.body.append(self._make_activation())

        # the last conv produces upscale**2 values per output channel for the pixel shuffle
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def _make_activation(self):
        """Return a new activation module for ``self.act_type`` (one instance per layer)."""
        if self.act_type == 'relu':
            return nn.ReLU(inplace=True)
        if self.act_type == 'prelu':
            return nn.PReLU(num_parameters=self.num_feat)
        return nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Upscale ``x`` by ``self.upscale``.

        Args:
            x (Tensor): Input batch with ``num_in_ch`` channels.

        Returns:
            Tensor: Pixel-shuffled network output plus the nearest-upsampled input.
        """
        out = x
        for layer in self.body:
            out = layer(out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out
|
stylegan2_clean_arch.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from basicsr.archs.arch_util import default_init_weights
|
6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class NormStyleCode(nn.Module):
    """Scale each style code by the inverse root-mean-square of its entries along dim 1."""

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: ``x`` multiplied by ``rsqrt(mean(x**2, dim=1) + 1e-8)``.
        """
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return x * inv_rms
|
23 |
+
|
24 |
+
|
25 |
+
class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    The convolution kernel is scaled per sample and per input channel by a linear
    projection of the style vector, optionally re-normalised ("demodulated"), and applied
    to the whole batch at once through a grouped convolution. The layer has no bias.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 eps=1e-8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        # style -> one scale per input channel
        self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
        default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')

        # shared kernel, drawn from a normal distribution and divided by sqrt(fan-in)
        fan_in = in_channels * kernel_size**2
        self.weight = nn.Parameter(
            torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / math.sqrt(fan_in))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        batch, in_ch, _, _ = x.shape
        # per-sample kernel: (1, c_out, c_in, k, k) * (b, 1, c_in, 1, 1) -> (b, c_out, c_in, k, k)
        scales = self.modulation(style).view(batch, 1, in_ch, 1, 1)
        kernels = self.weight * scales

        if self.demodulate:
            inv_norm = torch.rsqrt(kernels.pow(2).sum([2, 3, 4]) + self.eps)
            kernels = kernels * inv_norm.view(batch, self.out_channels, 1, 1, 1)

        kernels = kernels.view(batch * self.out_channels, in_ch, self.kernel_size, self.kernel_size)

        # resample the input if requested
        if self.sample_mode == 'upsample':
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        elif self.sample_mode == 'downsample':
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)

        batch, in_ch, height, width = x.shape
        # fold the batch into the channel axis so groups=batch gives every sample its own kernel
        folded = x.view(1, batch * in_ch, height, width)
        result = F.conv2d(folded, kernels, padding=self.padding, groups=batch)
        return result.view(batch, self.out_channels, *result.shape[2:4])

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})')
|
105 |
+
|
106 |
+
|
107 |
+
class StyleConv(nn.Module):
    """Style conv used in StyleGAN2.

    Applies a modulated convolution, adds weighted noise and a per-channel bias, and
    finishes with a leaky ReLU.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
        super().__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
        # scalar strength of the injected noise, starts at zero
        self.weight = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, style, noise=None):
        """Run the styled convolution.

        Args:
            x (Tensor): Input features.
            style (Tensor): Style code passed on to the modulated convolution.
            noise (Tensor | None): Noise to inject; when None, standard-normal noise with
                one channel and the spatial size of the conv output is drawn.

        Returns:
            Tensor: Activated features.
        """
        # modulated conv, scaled by sqrt(2)
        feat = self.modulated_conv(x, style) * 2**0.5
        if noise is None:
            batch, _, height, width = feat.shape
            noise = feat.new_empty(batch, 1, height, width).normal_()
        # noise injection, then bias, then activation
        return self.activate(feat + self.weight * noise + self.bias)
|
140 |
+
|
141 |
+
|
142 |
+
class ToRGB(nn.Module):
    """To RGB (image space) from features.

    A 1x1 modulated convolution without demodulation maps the features to three channels;
    a bias is added and, if given, the (optionally upsampled) skip image.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
    """

    def __init__(self, in_channels, num_style_feat, upsample=True):
        super().__init__()
        self.upsample = upsample
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        rgb = self.modulated_conv(x, style) + self.bias
        if skip is None:
            return rgb
        if self.upsample:
            # bring the lower-resolution image up to the current resolution before summing
            skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
        return rgb + skip
|
176 |
+
|
177 |
+
|
178 |
+
class ConstantInput(nn.Module):
    """Learned constant tensor that is repeated along the batch axis.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super().__init__()
        initial = torch.randn(1, num_channel, size, size)
        self.weight = nn.Parameter(initial)

    def forward(self, batch):
        """Return the constant repeated ``batch`` times, shape (batch, num_channel, size, size)."""
        return self.weight.repeat(batch, 1, 1, 1)
|
193 |
+
|
194 |
+
|
195 |
+
@ARCH_REGISTRY.register()
class StyleGAN2GeneratorClean(nn.Module):
    """Clean version of StyleGAN2 Generator.

    The generator starts from a learned 4x4 constant and doubles the resolution once per
    stage until ``out_size`` is reached. Every stage consists of two ``StyleConv`` layers
    and one ``ToRGB`` layer whose outputs are accumulated into the final image.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
        super(StyleGAN2GeneratorClean, self).__init__()
        # Style MLP layers: normalisation followed by num_mlp (Linear, LeakyReLU) pairs
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.extend(
                [nn.Linear(num_style_feat, num_style_feat, bias=True),
                 nn.LeakyReLU(negative_slope=0.2, inplace=True)])
        self.style_mlp = nn.Sequential(*style_mlp_layers)
        # initialization
        default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

        # channel list: feature width used at each resolution (keyed by the resolution as a string)
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        # 4x4 stem: learned constant, one style conv and one to-RGB layer without upsampling
        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)

        self.log_size = int(math.log(out_size, 2))
        # one style conv in the stem plus two for every stage from 8x8 up to out_size
        self.num_layers = (self.log_size - 2) * 2 + 1
        # number of per-layer style codes consumed in forward (highest index used is num_latent - 1)
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        # plain container module; it only holds the fixed noise buffers registered below
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise: one buffer per style conv; layer 0 is 4x4, then two layers per resolution (8, 8, 16, 16, ...)
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs: per stage an upsampling conv, a same-resolution conv and a to-RGB layer
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample'))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection.

        Returns:
            list[Tensor]: One (1, 1, 4, 4) tensor followed by two (1, 1, r, r) tensors for
            every resolution r = 8, 16, ..., 2**log_size, on the device of the constant input.
        """
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        """Map style codes ``x`` through the style MLP."""
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        """Average the style-MLP output over ``num_latent`` random standard-normal codes.

        Returns:
            Tensor: Mean latent with a leading dimension of size 1 (``keepdim=True``).
        """
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2GeneratorClean.

        Args:
            styles (list[Tensor]): Sample codes of styles. One entry uses the same code for
                every layer (or per-layer codes if it already has three dimensions); two
                entries are mixed at ``inject_index``.
            input_is_latent (bool): Whether input is latent style. Default: False.
            noise (list[Tensor | None] | None): Per-layer noise, indexed by style conv layer,
                or None. Default: None.
            randomize_noise (bool): Only used when ``noise`` is None: draw fresh noise in
                every style conv if True, otherwise use the stored noise buffers. Default: True.
            truncation (float): The truncation ratio. Default: 1.
            truncation_latent (Tensor | None): The truncation latent tensor. It is used
                whenever ``truncation < 1``. Default: None.
            inject_index (int | None): The injection index for mixing noise. Default: None.
            return_latents (bool): Whether to return style latents. Default: False.

        Returns:
            tuple: ``(image, latent)`` if ``return_latents`` is True, else ``(image, None)``.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation: move every style towards truncation_latent by the factor (1 - truncation)
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latents with injection
        # NOTE(review): only one or two styles are handled; any other count leaves `latent`
        # unassigned and the code below fails on it.
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            # first style drives layers [0, inject_index), second style the remaining ones
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        # i is the latent index of the first conv of the current stage; the to-RGB layer of a
        # stage uses the same latent index as the first conv of the next stage
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)  # feature back to the rgb space
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
            return image, None
|
weights/PutWeightsHere
ADDED
File without changes
|