DifFace / sampler.py
Zongsheng
first upload
06f26d7
raw
history blame
13.6 kB
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-07-13 16:59:27
import os
import random
import numpy as np
from math import ceil
from pathlib import Path
from einops import rearrange
from omegaconf import OmegaConf
from skimage import img_as_ubyte
from ResizeRight.resize_right import resize
from utils import util_net
from utils import util_image
from utils import util_common
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from basicsr.utils import img2tensor
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
from facelib.utils.face_restoration_helper import FaceRestoreHelper
class BaseSampler:
    '''
    Base sampler: picks the device, seeds the random number generators and
    builds the diffusion process and the network described by the configs.
    '''

    def __init__(self, configs):
        '''
        Input:
            configs: config, see the yaml file in folder ./configs/sample/
        '''
        self.configs = configs
        self.display = configs.display
        self.diffusion_cfg = configs.diffusion

        self.setup_dist()  # sets self.device and self.rank
        self.setup_seed()  # reads self.rank, so it must run after setup_dist()

        self.build_model()

    def setup_seed(self, seed=None):
        '''
        Seed the Python, NumPy and PyTorch random number generators.
        Input:
            seed: integer base seed; configs.seed is used when None.
                  The effective seed is seed + (rank + 1) * 10000.
        '''
        seed = self.configs.seed if seed is None else seed
        seed += (self.rank + 1) * 10000
        if self.rank == 0 and self.display:
            print(f'Setting random seed {seed}')
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def setup_dist(self):
        '''
        Select the compute device (CUDA when available, otherwise CPU) and set rank to 0.
        '''
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.rank = 0

    def build_model(self):
        '''
        Instantiate the diffusion object and the network from their config
        targets, optionally load the checkpoint, and switch the network to eval mode.
        '''
        obj = util_common.get_obj_from_str(self.configs.diffusion.target)
        self.diffusion = obj(**self.configs.diffusion.params)

        obj = util_common.get_obj_from_str(self.configs.model.target)
        model = obj(**self.configs.model.params).to(self.device)
        if self.configs.model.ckpt_path is not None:
            self.load_model(model, self.configs.model.ckpt_path)
        self.model = model
        self.model.eval()

    def load_model(self, model, ckpt_path=None):
        '''
        Load a checkpoint into the given model via util_net.reload_model.
        Input:
            model: network that receives the weights
            ckpt_path: checkpoint path; nothing happens when None
        '''
        if ckpt_path is None:
            return
        if self.rank == 0 and self.display:
            print(f'Loading from {ckpt_path}...')
        # Map onto the device chosen in setup_dist() so that the CPU fallback
        # also works (a hard-coded "cuda:<rank>" fails without a GPU).
        ckpt = torch.load(ckpt_path, map_location=self.device)
        util_net.reload_model(model, ckpt)
        if self.rank == 0 and self.display:
            print('Loaded Done')

    def reset_diffusion(self, diffusion_cfg):
        '''
        Rebuild self.diffusion with new parameters.
        Input:
            diffusion_cfg: keyword parameters for the diffusion constructor
                           (same role as configs.diffusion.params)
        '''
        # The diffusion class is resolved from the configured target, exactly as
        # in build_model(); the previously referenced factory function was never
        # imported into this module.
        obj = util_common.get_obj_from_str(self.configs.diffusion.target)
        self.diffusion = obj(**diffusion_cfg)
class DifIRSampler(BaseSampler):
    '''
    Sampler that first enhances a low-quality image with a restoration network
    (model_ir), then diffuses the result to an intermediate timestep and
    denoises it with the diffusion model.
    '''

    def build_model(self):
        '''
        Build the diffusion model (parent class), the optional restoration
        network and, for unaligned inputs, the face helper and the optional
        background/face upsampler.
        '''
        super().build_model()

        if self.configs.model_ir is not None:
            obj = util_common.get_obj_from_str(self.configs.model_ir.target)
            # Follow self.device instead of a hard-coded .cuda() so that the
            # CPU fallback chosen in setup_dist() is respected.
            model_ir = obj(**self.configs.model_ir.params).to(self.device)
            if self.configs.model_ir.ckpt_path is not None:
                self.load_model(model_ir, self.configs.model_ir.ckpt_path)
            self.model_ir = model_ir
            self.model_ir.eval()

        if not self.configs.aligned:
            # face detection model
            self.face_helper = FaceRestoreHelper(
                self.configs.detection.upscale,
                face_size=self.configs.im_size,
                crop_ratio=(1, 1),
                det_model=self.configs.detection.det_model,
                save_ext='png',
                use_parse=True,
                device=self.device,
            )

            # background super-resolution
            if self.configs.background_enhance or self.configs.face_upsample:
                bg_model = RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=2,
                )
                self.bg_model = RealESRGANer(
                    scale=2,
                    model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                    model=bg_model,
                    tile=400,
                    tile_pad=10,
                    pre_pad=0,
                    # half precision must be disabled in CPU mode
                    half=(self.device.type == 'cuda'),
                    device=self.device,
                )

    def sample_func_ir_aligned(
            self,
            y0,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            ):
        '''
        Input:
            y0: n x c x h x w torch tensor, low-quality image, [0, 1], RGB
                or, h x w x c, numpy array, [0, 255], uint8, BGR
            start_timesteps: integer, range [0, num_timesteps-1],
                for accelerated sampling (e.g., 'ddim250'), range [0, 249];
                self.diffusion.num_timesteps is used when None
            post_fun: post-processing for the enhanced image; when None the
                image is normalized with mean 0.5 and std 0.5
            model_kwargs_ir: additional parameters for restoration model
            need_restoration: whether to run the restoration model on y0;
                when False, y0 itself is diffused
        Output:
            sample: n x c x h x w, torch tensor, [0,1], RGB
            im_hq: n x c x im_size x im_size torch tensor, [0, 1], the clamped
                (and, if necessary, resized) image that was diffused
        '''
        if not isinstance(y0, torch.Tensor):
            y0 = img2tensor(y0, bgr2rgb=True, float32=True).unsqueeze(0) / 255.  # 1 x c x h x w, [0,1]

        if start_timesteps is None:
            start_timesteps = self.diffusion.num_timesteps

        if post_fun is None:
            def post_fun(x):
                return util_image.normalize_th(
                    im=x,
                    mean=0.5,
                    std=0.5,
                    reverse=False,
                )

        # basic image restoration
        device = next(self.model.parameters()).device
        y0 = y0.to(device=device, dtype=torch.float32)

        if need_restoration:
            with torch.no_grad():
                if model_kwargs_ir is None:
                    im_hq = self.model_ir(y0)
                else:
                    im_hq = self.model_ir(y0, **model_kwargs_ir)
        else:
            im_hq = y0
        # Out-of-place clamp: without restoration im_hq can be the very tensor
        # the caller passed in, which must not be modified behind their back.
        im_hq = im_hq.clamp(0.0, 1.0)

        im_size = self.configs.im_size
        h_old, w_old = im_hq.shape[2:4]
        needs_resize = not (h_old == im_size and w_old == im_size)
        if needs_resize:
            im_hq = resize(im_hq, out_shape=(im_size,) * 2).to(torch.float32)

        # diffuse for im_hq
        yt = self.diffusion.q_sample(
            x_start=post_fun(im_hq),
            t=torch.tensor([start_timesteps,] * im_hq.shape[0], device=device),
        )

        assert yt.shape[-1] == im_size and yt.shape[-2] == im_size

        # arguments shared by both sampling loops
        loop_kwargs = dict(
            shape=yt.shape,
            noise=yt,
            start_timesteps=start_timesteps,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
        )
        if 'ddim' in self.configs.diffusion.params.timestep_respacing:
            sample = self.diffusion.ddim_sample_loop(self.model, eta=0.0, **loop_kwargs)
        else:
            sample = self.diffusion.p_sample_loop(self.model, **loop_kwargs)

        sample = util_image.normalize_th(sample, reverse=True).clamp(0.0, 1.0)

        # bring the sample back to the resolution of the restored input
        if needs_resize:
            sample = resize(sample, out_shape=(h_old, w_old)).clamp(0.0, 1.0)

        return sample, im_hq

    def sample_func_bfr_unaligned(
            self,
            y0,
            bs=16,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            only_center_face=False,
            draw_box=False,
            ):
        '''
        Input:
            y0: h x w x c numpy array, uint8, BGR
            bs: batch size for face restoration
            start_timesteps: integer, range [0, num_timesteps-1],
                for accelerated sampling (e.g., 'ddim250'), range [0, 249]
            post_fun: post-processing for the enhanced image
            model_kwargs_ir: additional parameters for restoration model
            need_restoration: forwarded to sample_func_ir_aligned
            only_center_face: forwarded to the face detector
            draw_box: draw a box for each face
        Output:
            restored_img: h x w x c, numpy array, uint8, BGR
            restored_faces: list, h x w x c, numpy array, uint8, BGR
            cropped_faces: list, h x w x c, numpy array, uint8, BGR
        The upsampling factor of the pasted-back image comes from
        configs.detection.upscale (see build_model).
        '''
        def _process_batch(cropped_faces_list):
            # list of h x w x c BGR uint8 arrays -> b x c x h x w RGB tensor in [0, 1]
            cropped_face_t = torch.stack(
                img2tensor(cropped_faces_list, bgr2rgb=True, float32=True),
                dim=0,
            ) / 255.
            cropped_face_t = cropped_face_t.to(self.device)
            restored = self.sample_func_ir_aligned(
                cropped_face_t,
                start_timesteps=start_timesteps,
                post_fun=post_fun,
                model_kwargs_ir=model_kwargs_ir,
                need_restoration=need_restoration,
            )[0]  # [0, 1], b x c x h x w
            return restored

        assert not self.configs.aligned, 'sample_func_bfr_unaligned requires configs.aligned to be False'

        self.face_helper.clean_all()
        self.face_helper.read_image(y0)
        self.face_helper.get_face_landmarks_5(
            only_center_face=only_center_face,
            resize=640,
            eye_dist_threshold=5,
        )
        # align and warp each face
        self.face_helper.align_warp_face()

        # Restore the cropped faces in chunks of at most bs. When no face was
        # cropped the loop body never runs and restored_faces stays empty
        # instead of failing on an empty batch.
        all_cropped = self.face_helper.cropped_faces
        restored_faces = []
        for idx_start in range(0, len(all_cropped), bs):
            batch_restored = _process_batch(all_cropped[idx_start:idx_start + bs])
            restored_faces.extend(
                util_image.tensor2img(
                    list(batch_restored.split(1, dim=0)),
                    rgb2bgr=True,
                    min_max=(0, 1),
                    out_type=np.uint8,
                )
            )

        for face in restored_faces:
            self.face_helper.add_restored_face(face)

        # paste_back
        if self.configs.background_enhance:
            bg_img = self.bg_model.enhance(y0, outscale=self.configs.detection.upscale)[0]
        else:
            bg_img = None
        self.face_helper.get_inverse_affine(None)

        # paste each restored face to the input image
        # NOTE(review): with zero faces this presumably returns the (optionally
        # enhanced) input image unchanged - confirm against FaceRestoreHelper.
        paste_kwargs = dict(upsample_img=bg_img, draw_box=draw_box)
        if self.configs.face_upsample:
            paste_kwargs['face_upsampler'] = self.bg_model
        restored_img = self.face_helper.paste_faces_to_input_image(**paste_kwargs)

        cropped_faces = self.face_helper.cropped_faces

        return restored_img, restored_faces, cropped_faces
if __name__ == '__main__':
    import argparse

    # Command-line options as (flag, type, default, help) rows.
    cli_options = (
        ("--save_dir", str, "./save_dir", "Folder to save the checkpoints and training log"),
        ("--gpu_id", str, '', "GPU Index, e.g., 025"),
        ("--cfg_path", str, './configs/sample/iddpm_ffhq256.yaml', "Path of config files"),
        ("--bs", int, 32, "Batch size"),
        ("--num_images", int, 3000, "Number of sampled images"),
        ("--timestep_respacing", str, '1000', "Sampling steps for accelerate"),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # Load the yaml config and override it with the command-line values.
    configs = OmegaConf.load(args.cfg_path)
    configs.gpu_id = args.gpu_id
    configs.diffusion.params.timestep_respacing = args.timestep_respacing

    # NOTE(review): neither DiffusionSampler nor a sample_func method is defined
    # in the visible part of this file - confirm where they are expected to
    # come from before relying on this entry point.
    sampler_dist = DiffusionSampler(configs)

    sampler_dist.sample_func(
        bs=args.bs,
        num_images=args.num_images,
        save_dir=args.save_dir,
    )