# yerang's picture
# Update stf/stf-api-alternative/src/stf_alternative/dataset.py
# 1843e04 verified
import os.path
import random
from glob import glob
from pathlib import Path
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from stf_alternative.s2f_dir.src.mask_history import calc_poly
def frame_id(fname):
    """Return the integer frame number that prefixes *fname*'s base name.

    The number is everything before the first ``_`` of the base name, e.g.
    ``"frames/00012_abc.png"`` -> ``12``.  ``int()`` raises ``ValueError``
    if that prefix is not an integer.
    """
    prefix, _sep, _rest = os.path.basename(fname).partition("_")
    return int(prefix)
def masking(im, pts):
    """Fill the polygon *pts* on *im* with mid-gray and return OpenCV's result.

    ``cv2.fillPoly`` draws onto its image argument, so *im* itself is
    modified; callers in this module also reassign the returned array.
    """
    polygons = [pts]
    mid_gray = (128, 128, 128)
    return cv2.fillPoly(im, polygons, mid_gray)
# Image file suffixes the frame datasets pick up from a directory.  They are
# compared against ``Path.suffix`` exactly, so the match is case-sensitive.
accepted_format = {".webp", ".png", ".jpg"}
class LipGanImage(Dataset):
    """Frame dataset yielding a mouth-masked input stack and its ground truth.

    Every file in ``path`` whose suffix is in ``accepted_format`` is one item.
    Landmarks come from ``path/df_fan.pickle`` (column ``cropped_pts2d``
    indexed by ``frame_idx``). Each frame's landmarks are looked up by the
    numeric prefix of its file name (``frame_id``).

    Args:
        args: config object. Reads ``mask_ver`` (one version, or a list/tuple
            to sample from per item), optional ``keying_mask_ver``,
            ``smoothing_mask`` and ``num_ips``.
        path: frame directory, as ``str`` or ``pathlib.Path``.
        num_skip_frames: rotate the sorted frame list left by this many
            entries, so the first ``num_skip_frames`` frames are visited last.
    """

    def __init__(self, args, path, num_skip_frames=0):
        self.args = args
        paths = sorted(
            it for it in glob(f"{path}/*") if Path(it).suffix in accepted_format
        )
        self.paths = paths[num_skip_frames:] + paths[:num_skip_frames]
        self.mask_ver = (
            list(args.mask_ver)
            if isinstance(args.mask_ver, (list, tuple))
            else [args.mask_ver]
        )
        self.keying_mask_ver = (
            args.keying_mask_ver if "keying_mask_ver" in args else None
        )
        self.smoothing_mask = bool(args.smoothing_mask)
        self.num_ips = args.num_ips
        # Wrap in Path() so a plain-string ``path`` works too; ``str / str``
        # raises TypeError.
        # NOTE: read_pickle unpickles the file, so only load trusted data dirs.
        df = pd.read_pickle(Path(path) / "df_fan.pickle")
        self.df = df.set_index("frame_idx")["cropped_pts2d"]

    @staticmethod
    def _read_image(img_name):
        """Read *img_name* with all channels kept; fail loudly if unreadable."""
        img = cv2.imread(str(img_name), cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {img_name}")
        return img

    @staticmethod
    def _normalize(img):
        """Map pixel values via ``x * 2 / 255 - 1`` ([0, 255] -> [-1, 1])."""
        return img * 2.0 / 255.0 - 1.0

    @staticmethod
    def _with_alpha(img):
        """Append a constant-255 alpha channel when *img* has exactly 3 channels."""
        if img.shape[2] != 3:
            return img
        alpha = np.full_like(img[:, :, :1], 255)
        return np.concatenate([img, alpha], axis=2)

    def _mask_frame(self, sidx, frame):
        """Return ``(masked_frame, keying_mask)`` for the frame with id *sidx*.

        *frame* is painted in place when landmarks exist.
        """
        pts2d = self.df[sidx]
        if pts2d is None:
            # Some templates (e.g. greeting clips) have no landmark predictions.
            # In that case the original image is passed through unmasked and
            # the mask is left empty.
            return frame, np.zeros_like(frame, dtype=np.uint8)
        size = frame.shape[0]
        mask_ver = random.choice(self.mask_ver)
        pts = calc_poly[mask_ver](pts2d, size, randomness=False)
        if self.keying_mask_ver is None:
            keying_pts = pts
        else:
            keying_pts = calc_poly[self.keying_mask_ver](
                pts2d, size, randomness=False
            )
        if self.smoothing_mask:
            # NOTE(review): ``smoothing_mask`` is not imported or defined in
            # the visible module; enabling args.smoothing_mask raises
            # NameError unless it is provided elsewhere. TODO confirm.
            pts = smoothing_mask(pts)
        masked = masking(frame, pts)
        mask = masking(np.zeros_like(masked, dtype=np.uint8), keying_pts)
        return masked, mask

    def __getitem__(self, idx):
        """Return the sample dict for item *idx*.

        Keys:
            ``ips``: float32 array, concatenated along the channel axis. It
                holds the masked frame followed by ``num_ips`` copies of the
                unmasked frame, each scaled by ``x * 2 / 255 - 1`` (i.e. to
                [-1, 1] for 8-bit images).
            ``mask``: uint8 image with the keying polygon filled with 128.
            ``img_gt_with_alpha``: the image as read, with an opaque alpha
                channel appended when it has exactly 3 channels.
            ``filename``: path of the image file.
        """
        img_name = Path(self.paths[idx])
        sidx = frame_id(img_name.name)
        img_gt = self._read_image(img_name)
        frame = img_gt[:, :, :3].copy()
        # The reference inputs use the frame *before* the polygon is painted.
        img_ip = self._normalize(frame)
        masked, mask = self._mask_frame(sidx, frame)
        ips = np.concatenate(
            [self._normalize(masked)] + [img_ip] * self.num_ips, axis=2
        )
        return {
            "ips": ips.astype(np.float32),
            "mask": mask,
            "img_gt_with_alpha": self._with_alpha(img_gt),
            "filename": str(img_name),
        }

    def __len__(self):
        return len(self.paths)
class LipGanRemoteImage(Dataset):
    """Frame dataset yielding raw-pixel, channel-first masked input stacks.

    Items are built like the other frame dataset in this module, with two
    differences: pixel values are not rescaled, and ``ips`` is returned
    channel-first. Every file in ``path`` whose suffix is in
    ``accepted_format`` is one item. Landmarks come from
    ``path/df_fan.pickle`` (column ``cropped_pts2d`` indexed by
    ``frame_idx``).

    Args:
        args: config object. Reads ``mask_ver`` (one version, or a list/tuple
            to sample from per item), optional ``keying_mask_ver``,
            ``smoothing_mask`` and ``num_ips``.
        path: frame directory, as ``str`` or ``pathlib.Path``.
        num_skip_frames: rotate the sorted frame list left by this many
            entries, so the first ``num_skip_frames`` frames are visited last.
    """

    def __init__(self, args, path, num_skip_frames=0):
        self.args = args
        paths = sorted(
            it for it in glob(f"{path}/*") if Path(it).suffix in accepted_format
        )
        self.paths = paths[num_skip_frames:] + paths[:num_skip_frames]
        self.num_skip_frames = num_skip_frames
        self.mask_ver = (
            list(args.mask_ver)
            if isinstance(args.mask_ver, (list, tuple))
            else [args.mask_ver]
        )
        self.keying_mask_ver = (
            args.keying_mask_ver if "keying_mask_ver" in args else None
        )
        self.smoothing_mask = bool(args.smoothing_mask)
        self.num_ips = args.num_ips
        # Wrap in Path() so a plain-string ``path`` works too; ``str / str``
        # raises TypeError.
        # NOTE: read_pickle unpickles the file, so only load trusted data dirs.
        df = pd.read_pickle(Path(path) / "df_fan.pickle")
        self.df = df.set_index("frame_idx")["cropped_pts2d"]

    @staticmethod
    def _read_image(img_name):
        """Read *img_name* with all channels kept; fail loudly if unreadable."""
        img = cv2.imread(str(img_name), cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {img_name}")
        return img

    @staticmethod
    def _with_alpha(img):
        """Append a constant-255 alpha channel when *img* has exactly 3 channels."""
        if img.shape[2] != 3:
            return img
        alpha = np.full_like(img[:, :, :1], 255)
        return np.concatenate([img, alpha], axis=2)

    def _mask_frame(self, sidx, frame):
        """Return ``(masked_frame, keying_mask)`` for the frame with id *sidx*.

        *frame* is painted in place when landmarks exist.
        """
        pts2d = self.df[sidx]
        if pts2d is None:
            # No landmark predictions: pass the image through unmasked and
            # return an empty mask.
            return frame, np.zeros_like(frame, dtype=np.uint8)
        size = frame.shape[0]
        mask_ver = random.choice(self.mask_ver)
        pts = calc_poly[mask_ver](pts2d, size, randomness=False)
        if self.keying_mask_ver is None:
            keying_pts = pts
        else:
            keying_pts = calc_poly[self.keying_mask_ver](
                pts2d, size, randomness=False
            )
        if self.smoothing_mask:
            # NOTE(review): ``smoothing_mask`` is not imported or defined in
            # the visible module; enabling args.smoothing_mask raises
            # NameError unless it is provided elsewhere. TODO confirm.
            pts = smoothing_mask(pts)
        masked = masking(frame, pts)
        mask = masking(np.zeros_like(masked, dtype=np.uint8), keying_pts)
        return masked, mask

    def __getitem__(self, idx):
        """Return the sample dict for item *idx*.

        Keys:
            ``ips``: channel-first view ``(C, H, W)`` of the masked frame,
                followed by ``num_ips`` copies of the unmasked frame, all in
                the image's original dtype and value range.
            ``mask``: uint8 image with the keying polygon filled with 128.
            ``img_gt_with_alpha``: the image as read, with an opaque alpha
                channel appended when it has exactly 3 channels.
            ``filename``: path of the image file.
        """
        img_name = Path(self.paths[idx])
        sidx = frame_id(img_name.name)
        img_gt = self._read_image(img_name)
        frame = img_gt[:, :, :3].copy()
        # Keep an unmasked copy before the polygon is painted onto ``frame``.
        img_ip = frame.copy()
        masked, mask = self._mask_frame(sidx, frame)
        ips = np.concatenate([masked] + [img_ip] * self.num_ips, axis=2)
        return {
            "ips": ips.transpose(2, 0, 1),
            "mask": mask,
            "img_gt_with_alpha": self._with_alpha(img_gt),
            "filename": str(img_name),
        }

    def __len__(self):
        return len(self.paths)
def get_processed_audio_segment(
    center_frame_id, processed_wav, fps, sample_rate, hop_length=320, half_window=39
):
    """Cut a zero-filled window of audio features centred on a video frame.

    Video frame ``center_frame_id`` (at ``fps``) is mapped to a sample index
    at ``sample_rate``. Integer division by ``hop_length`` then turns that
    into a row ``c`` of ``processed_wav``. The result holds rows
    ``[c - half_window, c + half_window)``; rows outside ``processed_wav``
    are zeros.

    Args:
        center_frame_id: video frame index the window is centred on.
        processed_wav: array of per-hop features, indexed by row on axis 0.
        fps: video frame rate.
        sample_rate: audio sample rate the hop length refers to.
        hop_length: audio samples per row of ``processed_wav`` (default 320,
            i.e. 50 rows per second at 16 kHz).
        half_window: rows taken before the centre row; the window is twice
            this long.

    Returns:
        A new array (never a view of ``processed_wav``) of shape
        ``(2 * half_window,) + processed_wav.shape[1:]`` with the input's dtype.
    """
    center_idx = int(center_frame_id / fps * sample_rate) // hop_length
    start_idx = center_idx - half_window
    window_len = 2 * half_window
    segment = np.zeros(
        (window_len,) + processed_wav.shape[1:], dtype=processed_wav.dtype
    )
    # Copy only the rows that overlap the input instead of copying and padding
    # the whole feature array, which keeps each call O(window) rather than
    # O(len(processed_wav)).
    lo = max(start_idx, 0)
    hi = min(start_idx + window_len, len(processed_wav))
    if hi > lo:
        segment[lo - start_idx : hi - start_idx] = processed_wav[lo:hi]
    return segment
def zero_wav_mels_when_silent_center(
    mels, mel_ps, zero_mels, zero=-4, t_secs=0.25, verbose=False
):
    """Return ``zero_mels`` if the centre of ``mels`` is entirely silent.

    The check looks at a window of about ``t_secs * mel_ps`` columns centred
    on the middle of ``mels``'s second axis. If every value in that window
    equals ``zero``, the whole segment is replaced by ``zero_mels``.

    Args:
        mels: 2-D array; axis 1 is treated as time.
        mel_ps: number of ``mels`` columns per second.
        zero_mels: array returned (as-is, not copied) for silent segments.
        zero: value that marks silence.
        t_secs: width of the centre window in seconds; ``None`` disables the
            check.
        verbose: print the window geometry.

    Returns:
        ``zero_mels`` when the centre window is silent, otherwise ``mels``.
    """
    if t_secs is None:
        return mels
    t_size = t_secs * mel_ps
    _, t_axis = mels.shape
    if t_size >= t_axis:
        # The requested window is at least as wide as the segment; return it unchanged.
        return mels
    t_size_half = int(t_size * 0.5)
    if verbose:
        print(f"t_axis:{t_axis}, t_size_half: {t_size_half}")
    center = int(t_axis / 2)
    t_axis_s = max(center - t_size_half, 0)
    t_axis_e = min(center + t_size_half, t_axis)
    window = mels[:, t_axis_s:t_axis_e]
    # An empty window (t_size_half == 0) gives no evidence of silence, and
    # ``.all()`` on an empty array is vacuously True. Require a non-empty
    # window before zeroing the segment.
    if window.size and (window == zero).all():
        return zero_mels
    return mels
class LipGanAudio(Dataset):
    """Audio-feature dataset yielding one fixed-length window per video frame.

    Args:
        args: config object, read both by attribute (``model_type``,
            ``mel_step_size``, ``mel_ps``) and by key (optional
            ``silent_secs``).
        id_list: video frame ids, one per item.
        mel: feature array passed to ``get_processed_audio_segment``.
        fps: video frame rate used to map frame ids to audio time.

    Raises:
        ValueError: if ``args.model_type`` is ``"stf_v1"`` or ``"stf_v2"``.
    """

    def __init__(self, args, id_list, mel, fps):
        if args.model_type in ("stf_v1", "stf_v2"):
            # Raising a bare string is itself a TypeError in Python 3, which
            # hid this message. Raise a real exception instead.
            raise ValueError("Did not support version < stf_v3")
        self.id_list = id_list
        self.mel = mel
        self.fps = fps
        self.silent_secs = (
            args["silent_secs"] if "silent_secs" in args.keys() else None
        )
        # Substitute returned for silent windows: 96 x mel_step_size of -4.
        self.zero_mels = np.full((96, args.mel_step_size), -4, dtype=np.float32)
        self.mel_ps = args.mel_ps

    def __getitem__(self, idx):
        """Return ``{"mel": window}`` for frame ``id_list[idx]`` (16 kHz audio)."""
        # NOTE(review): get_processed_audio_segment windows ``self.mel`` along
        # axis 0, while zero_wav_mels_when_silent_center treats axis 1 as time
        # and substitutes a (96, mel_step_size) array. Confirm the layout.
        mel = get_processed_audio_segment(self.id_list[idx], self.mel, self.fps, 16000)
        mel = zero_wav_mels_when_silent_center(
            mels=mel,
            mel_ps=self.mel_ps,
            zero_mels=self.zero_mels,
            t_secs=self.silent_secs,
        )
        return {"mel": mel}

    def __len__(self):
        return len(self.id_list)