Spaces:
Sleeping
Sleeping
import os.path | |
import random | |
from glob import glob | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
from torch.utils.data import Dataset | |
from stf_alternative.s2f_dir.src.mask_history import calc_poly | |
def frame_id(fname):
    """Return the integer frame index encoded as the filename prefix.

    Only the text before the first underscore of the basename is parsed,
    e.g. ``"/x/00012_face.png" -> 12``. A basename without an underscore is
    parsed as a whole.
    """
    base = os.path.basename(fname)
    prefix, _sep, _rest = base.partition("_")
    return int(prefix)
def masking(im, pts):
    """Fill polygon ``pts`` on ``im`` with mid-gray and return the image.

    Args:
        im: image array to draw on. OpenCV draws into this buffer, so pass an
            array the caller is happy to have modified.
        pts: polygon vertices in the form accepted by ``cv2.fillPoly``.

    Returns:
        The array returned by ``cv2.fillPoly``, with the polygon set to
        (128, 128, 128).
    """
    fill_color = (128, 128, 128)
    polygons = [pts]
    painted = cv2.fillPoly(im, polygons, fill_color)
    return painted
# Image suffixes (with the leading dot) picked up by the image datasets. They
# are compared against ``Path(...).suffix`` exactly, so matching is case-sensitive.
accepted_format = {".webp", ".png", ".jpg"}
class LipGanImage(Dataset):
    """Frame dataset that builds masked-mouth model inputs from an image folder.

    Every file in ``path`` whose suffix is in ``accepted_format`` is one sample.
    Landmarks are read from ``<path>/df_fan.pickle``: the ``cropped_pts2d``
    column, indexed by ``frame_idx``. A file's frame index is the integer prefix
    of its name (see ``frame_id``).

    Args:
        args: config object. Reads ``mask_ver``, ``smoothing_mask`` and
            ``num_ips``, plus ``keying_mask_ver`` if ``"keying_mask_ver" in args``.
        path: image directory (``str`` or ``pathlib.Path``).
        num_skip_frames: rotate the sorted file list left by this many entries,
            so iteration starts at that file and wraps around.
    """

    def __init__(self, args, path, num_skip_frames=0):
        self.args = args
        paths = sorted(
            [it for it in glob(f"{path}/*") if Path(it).suffix in accepted_format]
        )
        self.paths = paths[num_skip_frames:] + paths[:num_skip_frames]
        self.mask_ver = (
            list(args.mask_ver)
            if isinstance(args.mask_ver, (list, tuple))
            else [args.mask_ver]
        )
        self.keying_mask_ver = (
            args.keying_mask_ver if "keying_mask_ver" in args else None
        )
        self.smoothing_mask = bool(args.smoothing_mask)
        self.num_ips = args.num_ips
        # Path() so a plain-string ``path`` works here as it does for glob above.
        # NOTE: read_pickle unpickles arbitrary objects; only use trusted data dirs.
        df = pd.read_pickle(Path(path) / "df_fan.pickle")
        self.df = df.set_index("frame_idx")["cropped_pts2d"]

    def __getitem__(self, idx):
        img_name = Path(self.paths[idx])
        sidx = frame_id(img_name.name)
        img_gt = cv2.imread(str(img_name), cv2.IMREAD_UNCHANGED)
        if img_gt is None:
            # cv2.imread reports unreadable or missing files by returning None.
            raise FileNotFoundError(f"could not read image: {img_name}")
        # NOTE(review): assumes a colour image (H, W, C>=3); a single-channel
        # file would make the ``[:, :, :3]`` slice fail.
        masked = img_gt[:, :, :3].copy()
        # Reference frame scaled from [0, 255] to [-1, 1]. It is computed before
        # ``masked`` is painted, so it keeps the original mouth region.
        img_ip = masked * 2.0 / 255.0 - 1.0
        pts2d = self.df[sidx]
        if pts2d is None:
            # Some frames (e.g. from the greeting template) have no landmark
            # predictions; pass the original image through with an empty mask.
            mask = np.zeros_like(masked, dtype=np.uint8)
        else:
            masked, mask = self._apply_masks(masked, pts2d)
        img_ips = [img_ip] * self.num_ips
        # Channel-stack the masked frame with ``num_ips`` copies of the reference.
        ips = np.concatenate([masked * 2.0 / 255.0 - 1.0] + img_ips, axis=2)
        return {
            "ips": ips.astype(np.float32),
            "mask": mask,
            "img_gt_with_alpha": self._add_opaque_alpha(img_gt),
            "filename": str(img_name),
        }

    def _apply_masks(self, masked, pts2d):
        """Paint the mouth polygon on ``masked`` and build the keying mask.

        Returns ``(masked, mask)``. ``masked`` is drawn on with the polygon of a
        randomly chosen ``mask_ver`` (optionally smoothed). ``mask`` is a new
        uint8 image that is zero except for the keying polygon, which is filled
        with 128. Without ``keying_mask_ver``, the keying polygon is the
        unsmoothed mask polygon.
        """
        mask_ver = random.choice(self.mask_ver)
        size = masked.shape[0]  # presumably the (square) crop size — TODO confirm
        pts = calc_poly[mask_ver](pts2d, size, randomness=False)
        if self.keying_mask_ver is not None:
            keying_pts = calc_poly[self.keying_mask_ver](
                pts2d, size, randomness=False
            )
        else:
            keying_pts = pts
        if self.smoothing_mask:
            # NOTE(review): ``smoothing_mask`` is not defined or imported in the
            # visible part of this module; if it is missing, this branch raises
            # NameError. Confirm the intended helper and import it.
            pts = smoothing_mask(pts)
        masked = masking(masked, pts)
        mask = masking(np.zeros_like(masked, dtype=np.uint8), keying_pts)
        return masked, mask

    @staticmethod
    def _add_opaque_alpha(img):
        """Append a constant-255 alpha channel to a 3-channel image.

        Images with any other channel count are returned unchanged.
        """
        if img.shape[2] != 3:
            return img
        alpha = np.full_like(img[:, :, :1], 255)
        return np.concatenate([img, alpha], axis=2)

    def __len__(self):
        return len(self.paths)
class LipGanRemoteImage(Dataset):
    """Frame dataset that returns unscaled, channel-first masked-mouth inputs.

    Every file in ``path`` whose suffix is in ``accepted_format`` is one sample.
    Landmarks are read from ``<path>/df_fan.pickle``: the ``cropped_pts2d``
    column, indexed by ``frame_idx``. ``"ips"`` stays in the image's own dtype
    (no [-1, 1] scaling, no float cast) and is transposed to ``(C, H, W)``.
    Presumably normalisation happens on the remote side — TODO confirm.

    Args:
        args: config object. Reads ``mask_ver``, ``smoothing_mask`` and
            ``num_ips``, plus ``keying_mask_ver`` if ``"keying_mask_ver" in args``.
        path: image directory (``str`` or ``pathlib.Path``).
        num_skip_frames: rotate the sorted file list left by this many entries,
            so iteration starts at that file and wraps around.
    """

    def __init__(self, args, path, num_skip_frames=0):
        self.args = args
        paths = sorted(
            [it for it in glob(f"{path}/*") if Path(it).suffix in accepted_format]
        )
        self.paths = paths[num_skip_frames:] + paths[:num_skip_frames]
        self.num_skip_frames = num_skip_frames
        self.mask_ver = (
            list(args.mask_ver)
            if isinstance(args.mask_ver, (list, tuple))
            else [args.mask_ver]
        )
        self.keying_mask_ver = (
            args.keying_mask_ver if "keying_mask_ver" in args else None
        )
        self.smoothing_mask = bool(args.smoothing_mask)
        self.num_ips = args.num_ips
        # Path() so a plain-string ``path`` works here as it does for glob above.
        # NOTE: read_pickle unpickles arbitrary objects; only use trusted data dirs.
        df = pd.read_pickle(Path(path) / "df_fan.pickle")
        self.df = df.set_index("frame_idx")["cropped_pts2d"]

    def __getitem__(self, idx):
        img_name = Path(self.paths[idx])
        sidx = frame_id(img_name.name)
        img_gt = cv2.imread(str(img_name), cv2.IMREAD_UNCHANGED)
        if img_gt is None:
            # cv2.imread reports unreadable or missing files by returning None.
            raise FileNotFoundError(f"could not read image: {img_name}")
        # NOTE(review): assumes a colour image (H, W, C>=3); a single-channel
        # file would make the ``[:, :, :3]`` slices fail.
        masked = img_gt[:, :, :3].copy()
        # Separate, unmasked copy used as the reference input.
        img_ip = img_gt[:, :, :3].copy()
        pts2d = self.df[sidx]
        if pts2d is None:
            # Frames without landmark predictions pass through with an empty mask.
            mask = np.zeros_like(masked, dtype=np.uint8)
        else:
            masked, mask = self._apply_masks(masked, pts2d)
        img_ips = [img_ip] * self.num_ips
        # Channel-stack the masked frame with ``num_ips`` reference copies.
        ips = np.concatenate([masked] + img_ips, axis=2)
        return {
            "ips": ips.transpose(2, 0, 1),
            "mask": mask,
            "img_gt_with_alpha": self._add_opaque_alpha(img_gt),
            "filename": str(img_name),
        }

    def _apply_masks(self, masked, pts2d):
        """Paint the mouth polygon on ``masked`` and build the keying mask.

        Returns ``(masked, mask)``. ``masked`` is drawn on with the polygon of a
        randomly chosen ``mask_ver`` (optionally smoothed). ``mask`` is a new
        uint8 image that is zero except for the keying polygon, which is filled
        with 128. Without ``keying_mask_ver``, the keying polygon is the
        unsmoothed mask polygon.
        """
        mask_ver = random.choice(self.mask_ver)
        size = masked.shape[0]  # presumably the (square) crop size — TODO confirm
        pts = calc_poly[mask_ver](pts2d, size, randomness=False)
        if self.keying_mask_ver is not None:
            keying_pts = calc_poly[self.keying_mask_ver](
                pts2d, size, randomness=False
            )
        else:
            keying_pts = pts
        if self.smoothing_mask:
            # NOTE(review): ``smoothing_mask`` is not defined or imported in the
            # visible part of this module; if it is missing, this branch raises
            # NameError. Confirm the intended helper and import it.
            pts = smoothing_mask(pts)
        masked = masking(masked, pts)
        mask = masking(np.zeros_like(masked, dtype=np.uint8), keying_pts)
        return masked, mask

    @staticmethod
    def _add_opaque_alpha(img):
        """Append a constant-255 alpha channel to a 3-channel image.

        Images with any other channel count are returned unchanged.
        """
        if img.shape[2] != 3:
            return img
        alpha = np.full_like(img[:, :, :1], 255)
        return np.concatenate([img, alpha], axis=2)

    def __len__(self):
        return len(self.paths)
def get_processed_audio_segment(
    center_frame_id, processed_wav, fps, sample_rate, hop_length=320, half_window=39
):
    """Return a zero-padded window of feature rows centred on a video frame.

    The video frame ``center_frame_id`` (at ``fps``) is mapped to an audio
    sample index at ``sample_rate``, then to a feature-row index by integer
    division by ``hop_length``. Rows ``[center - half_window, center +
    half_window)`` of ``processed_wav`` are returned; rows that fall outside
    the array are zeros.

    Args:
        center_frame_id: video frame index at the window centre.
        processed_wav: feature array indexed by row along axis 0.
        fps: video frames per second.
        sample_rate: audio samples per second.
        hop_length: audio samples per feature row. The default of 320 is
            20 ms at 16 kHz — presumably the feature extractor's hop.
        half_window: rows taken on each side of the centre (default 39, i.e.
            78 rows in total).

    Returns:
        A new array of shape ``(2 * half_window, *processed_wav.shape[1:])`` with
        ``processed_wav``'s dtype. It never aliases ``processed_wav``.
    """
    processed_wav = np.asarray(processed_wav)
    center_idx = int(center_frame_id / fps * sample_rate) // hop_length
    start = center_idx - half_window
    end = center_idx + half_window
    # Copy only the in-range rows into a zero-filled window, rather than copying
    # and padding the entire feature array on every call.
    window = np.zeros(
        (end - start,) + processed_wav.shape[1:], dtype=processed_wav.dtype
    )
    lo = max(start, 0)
    hi = min(end, len(processed_wav))
    if lo < hi:
        window[lo - start : hi - start] = processed_wav[lo:hi]
    return window
def zero_wav_mels_when_silent_center(
    mels, mel_ps, zero_mels, zero=-4, t_secs=0.25, verbose=False
):
    """Replace ``mels`` with ``zero_mels`` when its central time span is silent.

    ``mels`` is unpacked as 2-D, with its second axis treated as time. The centre
    window spans about ``t_secs * mel_ps`` columns. If every value in it equals
    ``zero``, ``zero_mels`` is returned.

    Args:
        mels: 2-D array whose second axis is treated as time.
        mel_ps: columns per second along the time axis.
        zero_mels: array returned for silent input.
        zero: value that marks silence (default -4).
        t_secs: width of the centre window in seconds; ``None`` disables the check.
        verbose: print the window size for debugging.

    Returns:
        ``zero_mels`` if the centre window is entirely ``zero``, else ``mels``.
        ``mels`` is also returned when the window is at least as wide as the
        input, or when it rounds down to zero columns.
    """
    if t_secs is None:
        return mels
    t_size = t_secs * mel_ps
    _, t_axis = mels.shape
    if t_size >= t_axis:
        # The requested span is at least as wide as the input: return it as-is.
        return mels
    t_size_half = int(t_size * 0.5)
    if verbose:
        print(f"t_axis:{t_axis}, t_size_half: {t_size_half}")
    center = t_axis // 2
    start = max(center - t_size_half, 0)
    end = min(center + t_size_half, t_axis)
    if end <= start:
        # A zero-width window makes ``.all()`` vacuously True, which would
        # silence every input; with no evidence of silence, keep ``mels``.
        return mels
    if (mels[:, start:end] == zero).all():
        return zero_mels
    return mels
class LipGanAudio(Dataset):
    """Per-frame audio-feature windows for the frames listed in ``id_list``.

    Args:
        args: config mapping. Reads ``model_type``, ``mel_step_size`` and
            ``mel_ps`` as attributes, and ``silent_secs`` via ``keys()``/``[]``
            if present.
        id_list: frame ids; ``self[i]`` is the window centred on ``id_list[i]``.
        mel: feature array sliced by ``get_processed_audio_segment``.
        fps: video frame rate used to map frame ids to audio time.
        sample_rate: audio sample rate used for that mapping. The default of
            16000 matches the previously hard-coded value.

    Raises:
        ValueError: if ``args.model_type`` is ``"stf_v1"`` or ``"stf_v2"``.
    """

    def __init__(self, args, id_list, mel, fps, sample_rate=16000):
        if args.model_type in ("stf_v1", "stf_v2"):
            # Raise a real exception type: raising a plain str is a TypeError.
            raise ValueError(
                f"model_type {args.model_type!r} is not supported; "
                "stf_v3 or later is required"
            )
        self.id_list = id_list
        self.mel = mel
        self.fps = fps
        self.sample_rate = sample_rate
        self.silent_secs = (
            args["silent_secs"] if "silent_secs" in args.keys() else None
        )
        # NOTE(review): this (96, mel_step_size) template replaces the window
        # from get_processed_audio_segment, which has 78 rows by default. Confirm
        # the two shapes agree for the configured features.
        self.zero_mels = np.full((96, args.mel_step_size), -4, dtype=np.float32)
        self.mel_ps = args.mel_ps

    def __getitem__(self, idx):
        mel = get_processed_audio_segment(
            self.id_list[idx], self.mel, self.fps, self.sample_rate
        )
        mel = zero_wav_mels_when_silent_center(
            mels=mel,
            mel_ps=self.mel_ps,
            zero_mels=self.zero_mels,
            t_secs=self.silent_secs,
        )
        return {
            "mel": mel,
        }

    def __len__(self):
        return len(self.id_list)