"""Dataset classes for LipGan image frames and audio feature windows."""

import os.path
import random
from glob import glob
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from stf_alternative.s2f_dir.src.mask_history import calc_poly

# Image file suffixes picked up when listing a frame directory.
accepted_format = {".webp", ".png", ".jpg"}

# Constants of the audio feature window used by get_processed_audio_segment.
# NOTE(review): 320 looks like the number of audio samples per feature row and
# 39 the number of rows taken on each side of the center — confirm against the
# audio preprocessing code.
_SAMPLES_PER_FEATURE_ROW = 320
_HALF_WINDOW_ROWS = 39


def frame_id(fname):
    """Return the integer that precedes the first "_" in the base name of *fname*.

    Raises:
        ValueError: if that leading part is not a valid integer.
    """
    return int(os.path.basename(fname).split("_")[0])


def masking(im, pts):
    """Fill the polygon *pts* on *im* with gray (128, 128, 128) and return *im*.

    The fill is done by ``cv2.fillPoly``, which draws directly on *im*.
    """
    return cv2.fillPoly(im, [pts], (128, 128, 128))


def _read_image(img_name):
    """Read *img_name* with all channels preserved, failing loudly if unreadable."""
    img = cv2.imread(str(img_name), cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Failed to read image: {img_name}")
    return img


def _to_unit_range(img):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return img * 2.0 / 255.0 - 1.0


def _with_opaque_alpha(img):
    """Return *img* with a constant 255 alpha channel appended if it has 3 channels."""
    if img.shape[2] != 3:
        return img
    alpha = np.full_like(img[:, :, :1], 255)
    return np.concatenate([img, alpha], axis=2)


class _LipGanFrameDataset(Dataset):
    """Shared setup and masking logic for the frame datasets of this module.

    Args:
        args: configuration object; ``mask_ver``, ``smoothing_mask`` and
            ``num_ips`` are read as attributes, and ``keying_mask_ver`` is
            read only when ``"keying_mask_ver" in args`` holds.
        path: directory holding the frame images and ``df_fan.pickle``
            (``str`` or ``Path``).
        num_skip_frames: number of leading frames (in sorted order) that are
            rotated to the end of the frame list.
    """

    def __init__(self, args, path, num_skip_frames=0):
        self.args = args
        paths = sorted(
            it for it in glob(f"{path}/*") if Path(it).suffix in accepted_format
        )
        # Rotate rather than drop: the skipped frames are appended at the end.
        self.paths = paths[num_skip_frames:] + paths[:num_skip_frames]
        self.num_skip_frames = num_skip_frames
        self.mask_ver = (
            list(args.mask_ver)
            if isinstance(args.mask_ver, (list, tuple))
            else [args.mask_ver]
        )
        self.keying_mask_ver = (
            args.keying_mask_ver if "keying_mask_ver" in args else None
        )
        self.smoothing_mask = bool(args.smoothing_mask)
        self.num_ips = args.num_ips

        # Path(...) keeps this working when `path` is a plain string.
        # NOTE: unpickling can execute arbitrary code; only load pickles
        # produced by a trusted preprocessing step.
        df = pd.read_pickle(Path(path) / "df_fan.pickle")
        self.df = df.set_index("frame_idx")["cropped_pts2d"]

    def __len__(self):
        return len(self.paths)

    def _mask_face(self, masked, sidx):
        """Mask *masked* for frame *sidx* and build the matching keying mask.

        Args:
            masked: 3-channel image; the mask polygon is drawn directly on it.
            sidx: frame index used to look up the landmarks in ``self.df``.

        Returns:
            ``(masked, mask)`` where ``mask`` is a ``uint8`` array shaped like
            ``masked`` with the keying polygon filled with 128.
        """
        pts2d = self.df[sidx]
        if pts2d is None:
            # snow: since greeting templates were introduced, some frames have
            # no preds. In that case the original image is returned as-is,
            # without a mask.
            return masked, np.zeros_like(masked, dtype=np.uint8)

        mask_ver = random.choice(self.mask_ver)
        height = masked.shape[0]
        pts = calc_poly[mask_ver](pts2d, height, randomness=False)
        if self.keying_mask_ver is not None:
            keying_pts = calc_poly[self.keying_mask_ver](
                pts2d, height, randomness=False
            )
        else:
            # Taken before smoothing, so the keying mask uses the raw polygon.
            keying_pts = pts
        if self.smoothing_mask:
            # NOTE(review): `smoothing_mask` is neither defined nor imported in
            # the visible source, so this branch raises NameError when
            # args.smoothing_mask is truthy — confirm where it should come from.
            pts = smoothing_mask(pts)

        masked = masking(masked, pts)
        mask = masking(np.zeros_like(masked, dtype=np.uint8), keying_pts)
        return masked, mask


class LipGanImage(_LipGanFrameDataset):
    """Frame dataset that yields inputs normalized to [-1, 1] as float32.

    Each item is a dict with:
        "ips": the masked frame followed by ``num_ips`` copies of the unmasked
            frame, concatenated on the channel axis (last axis), float32.
        "mask": uint8 keying mask shaped like the 3-channel frame.
        "img_gt_with_alpha": the frame as read, with an alpha channel of 255
            appended when it had only 3 channels.
        "filename": path of the frame as a string.
    """

    def __getitem__(self, idx):
        img_name = Path(self.paths[idx])
        sidx = frame_id(img_name.name)
        img_gt = _read_image(img_name)

        face = img_gt[:, :, :3]
        img_ip = _to_unit_range(face)
        # Mask a copy so `img_gt` and `img_ip` keep the unmasked pixels.
        masked, mask = self._mask_face(face.copy(), sidx)

        ips = np.concatenate(
            [_to_unit_range(masked)] + [img_ip] * self.num_ips, axis=2
        )
        return {
            "ips": ips.astype(np.float32),
            "mask": mask,
            "img_gt_with_alpha": _with_opaque_alpha(img_gt),
            "filename": str(img_name),
        }


class LipGanRemoteImage(_LipGanFrameDataset):
    """Frame dataset that yields raw (un-normalized) channel-first inputs.

    Each item is a dict with:
        "ips": the masked frame followed by ``num_ips`` copies of the unmasked
            frame, concatenated on the channel axis and transposed to
            (channels, height, width); pixel values are left as read.
        "mask": uint8 keying mask shaped like the 3-channel frame.
        "img_gt_with_alpha": the frame as read, with an alpha channel of 255
            appended when it had only 3 channels.
        "filename": path of the frame as a string.
    """

    def __getitem__(self, idx):
        img_name = Path(self.paths[idx])
        sidx = frame_id(img_name.name)
        img_gt = _read_image(img_name)

        face = img_gt[:, :, :3]
        # Mask a copy so `img_gt` and `face` keep the unmasked pixels.
        masked, mask = self._mask_face(face.copy(), sidx)

        ips = np.concatenate([masked] + [face] * self.num_ips, axis=2)
        return {
            "ips": ips.transpose(2, 0, 1),
            "mask": mask,
            "img_gt_with_alpha": _with_opaque_alpha(img_gt),
            "filename": str(img_name),
        }


def get_processed_audio_segment(center_frame_id, processed_wav, fps, sample_rate):
    """Cut a fixed-length window of feature rows around a video frame.

    The frame index is converted to an audio sample position
    (``center_frame_id / fps * sample_rate``) and then to a feature row by
    integer division by 320. The window nominally starts 39 rows before that
    row and is 78 rows long. Rows that fall outside ``processed_wav`` are
    zero-filled, so the result always has exactly 78 rows.

    Args:
        center_frame_id: index of the video frame at the window center.
        processed_wav: 2-D array of feature rows (rows on axis 0).
        fps: video frame rate.
        sample_rate: audio sample rate.

    Returns:
        A new array of shape ``(78, processed_wav.shape[1])``; the input is
        never modified or aliased.
    """
    window = 2 * _HALF_WINDOW_ROWS
    time_center = center_frame_id / fps
    center_idx = int(time_center * sample_rate) // _SAMPLES_PER_FEATURE_ROW
    start_idx = center_idx - _HALF_WINDOW_ROWS
    end_idx = start_idx + window

    # Slice only the rows that exist instead of copying the whole array.
    total = len(processed_wav)
    lo = min(max(start_idx, 0), total)
    hi = max(min(max(end_idx, 0), total), lo)
    body = processed_wav[lo:hi]

    pad_front = min(max(-start_idx, 0), window)
    pad_back = window - pad_front - len(body)
    # np.pad always allocates, so the caller gets an independent array.
    return np.pad(
        body,
        ((pad_front, pad_back), (0, 0)),
        mode="constant",
        constant_values=0,
    )


def zero_wav_mels_when_silent_center(
    mels, mel_ps, zero_mels, zero=-4, t_secs=0.25, verbose=False
):
    """Return ``zero_mels`` when the center of ``mels`` is entirely silent.

    A window of about ``t_secs * mel_ps`` columns centered on axis 1 of
    ``mels`` is inspected; if every value in it equals ``zero`` the whole
    input is replaced by ``zero_mels``.

    Args:
        mels: 2-D array; axis 1 is treated as the time axis.
        mel_ps: number of columns per second.
        zero_mels: array returned in place of ``mels`` when the center is
            silent.
        zero: value that marks silence.
        t_secs: length of the inspected window in seconds; ``None`` disables
            the check.
        verbose: print the computed window sizes.

    Returns:
        ``zero_mels`` if the inspected window is non-empty and all ``zero``,
        otherwise ``mels`` unchanged.
    """
    if t_secs is None:
        return mels

    t_size = t_secs * mel_ps
    _, t_axis = mels.shape
    if t_size >= t_axis:
        # If the requested window is larger than the span being looked at,
        # return the input unchanged.
        return mels

    t_size_half = int(t_size * 0.5)
    if verbose:
        print(f"t_axis:{t_axis}, t_size_half: {t_size_half}")

    center = int(t_axis / 2)
    t_axis_s = max(center - t_size_half, 0)
    t_axis_e = min(center + t_size_half, t_axis)
    if t_axis_e <= t_axis_s:
        # An empty window would be vacuously "all silent"; treat it as no check.
        return mels

    if (mels[:, t_axis_s:t_axis_e] == zero).all():
        return zero_mels
    return mels


class LipGanAudio(Dataset):
    """Dataset that yields one audio feature window per frame id.

    Args:
        args: configuration object; ``model_type``, ``mel_step_size`` and
            ``mel_ps`` are read as attributes, ``silent_secs`` via
            ``args.keys()`` / ``args["silent_secs"]`` when present.
        id_list: frame ids, one per dataset item.
        mel: 2-D feature array the windows are cut from.
        fps: video frame rate used to map frame ids to audio positions.

    Raises:
        ValueError: if ``args.model_type`` is ``"stf_v1"`` or ``"stf_v2"``.
    """

    def __init__(self, args, id_list, mel, fps):
        if args.model_type in ("stf_v1", "stf_v2"):
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError("Did not support version < stf_v3")

        self.id_list = id_list
        self.mel = mel
        self.fps = fps
        self.silent_secs = (
            args["silent_secs"] if "silent_secs" in args.keys() else None
        )
        self.zero_mels = np.full((96, args.mel_step_size), -4, dtype=np.float32)
        self.mel_ps = args.mel_ps

    def __getitem__(self, idx):
        """Return ``{"mel": window}`` for the frame id at position *idx*."""
        mel = get_processed_audio_segment(self.id_list[idx], self.mel, self.fps, 16000)
        mel = zero_wav_mels_when_silent_center(
            mels=mel,
            mel_ps=self.mel_ps,
            zero_mels=self.zero_mels,
            t_secs=self.silent_secs,
        )
        return {
            "mel": mel,
        }

    def __len__(self):
        return len(self.id_list)