import asyncio
import json
from pathlib import Path
import asyncstdlib
import numpy as np
import pandas as pd
from pydub import AudioSegment
from stf_alternative.compose import get_compose_func_without_keying, get_keying_func
from stf_alternative.dataset import LipGanAudio, LipGanImage, LipGanRemoteImage
from stf_alternative.inference import (
adictzip,
ainference_model_remote,
audio_encode,
dictzip,
get_head_box,
inference_model,
inference_model_remote,
)
from stf_alternative.preprocess_dir.utils import face_finder as ff
from stf_alternative.readers import (
AsyncProcessPoolBatchIterator,
ProcessPoolBatchIterator,
get_image_folder_async_process_reader,
get_image_folder_process_reader,
)
from stf_alternative.util import (
acycle,
get_crop_mp4_dir,
get_frame_dir,
get_preprocess_dir,
icycle,
read_config,
)
def calc_audio_std(audio_segment):
    # Match the numpy dtype to the segment's sample width so wide samples are
    # not truncated, and normalize the std by that dtype's max value.
    dtype = (
        np.int8
        if audio_segment.sample_width == 1
        else np.int16
        if audio_segment.sample_width == 2
        else np.int32
    )
    sample = np.array(audio_segment.get_array_of_samples(), dtype=dtype)
    max_value = np.iinfo(dtype).max
    return sample.std() / max_value, len(sample)
class RunningAudioNormalizer:
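    """Rescales audio so its loudness tracks a reference segment, using
    exponentially decayed running statistics."""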
def __init__(self, ref_audio_segment, decay_rate=0.01):
self.ref_std, _ = calc_audio_std(ref_audio_segment)
self.running_var = np.float64(0)
self.running_cnt = 0
self.decay_rate = decay_rate
    def __call__(self, audio_segment):
        std, cnt = calc_audio_std(audio_segment)
        # Exponentially decayed running estimate of the signal's variance.
        self.running_var = (self.running_var + (std**2) * cnt) * (1 - self.decay_rate)
        self.running_cnt = (self.running_cnt + cnt) * (1 - self.decay_rate)
        # Rescale samples toward the reference std (assumes 16-bit samples).
        samples = np.asarray(audio_segment.get_array_of_samples())
        return audio_segment._spawn(
            (samples / self.std * self.ref_std).astype(np.int16).tobytes()
        )
@property
def std(self):
return np.sqrt(self.running_var / self.running_cnt)
def get_video_metadata(preprocess_dir):
json_path = preprocess_dir / "metadata.json"
with open(json_path, "r") as f:
return json.load(f)
class Template:
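    """Lip-sync inference over a preprocessed template video: pairs mel
    batches with face crops, runs the (local or remote) model, and composes
    the inferred mouth region back onto the template frames."""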
def __init__(
self,
config_path,
model,
template_video_path,
wav_std=False,
ref_wav=None,
verbose=False,
):
self.config = read_config(config_path)
self.model = model
self.template_video_path = Path(template_video_path)
self.preprocess_dir = Path(
get_preprocess_dir(model.work_root_path, model.args.name)
)
self.crop_mp4_dir = Path(
get_crop_mp4_dir(self.preprocess_dir, template_video_path)
)
self.dataset_dir = self.crop_mp4_dir / f"{Path(template_video_path).stem}_000"
self.template_frames_path = Path(
get_frame_dir(self.preprocess_dir, template_video_path, ratio=1.0)
)
self.verbose = verbose
self.remote = self.model.args.model_type == "remote"
self.audio_normalizer = (
RunningAudioNormalizer(ref_wav) if wav_std else lambda x: x
)
self.df = pd.read_pickle(self.dataset_dir / "df_fan.pickle")
metadata = get_video_metadata(self.preprocess_dir)
self.fps = metadata["fps"]
self.width, self.height = metadata["width"], metadata["height"]
self.keying_func = get_keying_func(self)
self.compose_func = get_compose_func_without_keying(self, ratio=1.0)
self.move = "move" in self.config.keys() and self.config.move
self.inference_func = inference_model_remote if self.remote else inference_model
self.batch_size = self.model.args.batch_size
        self.unit = 1000 / self.fps  # milliseconds of audio per video frame
def _get_reader(self, num_skip_frames):
assert self.template_frames_path.exists()
return get_image_folder_process_reader(
data_path=self.template_frames_path,
num_skip_frames=num_skip_frames,
preload=self.batch_size,
)
def _get_local_face_dataset(self, num_skip_frames):
return LipGanImage(
args=self.model.args,
path=self.dataset_dir,
num_skip_frames=num_skip_frames,
)
def _get_remote_face_dataset(self, num_skip_frames):
return LipGanRemoteImage(
args=self.model.args,
path=self.dataset_dir,
num_skip_frames=num_skip_frames,
)
    def _get_mel_dataset(self, audio_segment):
        image_count = round(
            audio_segment.duration_seconds * self.fps
        )  # the audio was padded, so this divides evenly by batch_size
ids = list(range(image_count))
mel = audio_encode(
model=self.model,
audio_segment=audio_segment,
device=self.model.device,
)
return LipGanAudio(
args=self.model.args,
id_list=ids,
mel=mel,
fps=self.fps,
)
def _get_face_dataset(self, num_skip_frames):
if self.remote:
return self._get_remote_face_dataset(num_skip_frames=num_skip_frames)
else:
return self._get_local_face_dataset(num_skip_frames=num_skip_frames)
def _wrap_reader(self, reader):
reader = icycle(reader)
return reader
def _wrap_dataset(self, dataset):
dataloader = ProcessPoolBatchIterator(
dataset=dataset,
batch_size=self.batch_size,
)
return dataloader
def get_reader(self, num_skip_frames=0):
reader = self._get_reader(num_skip_frames=num_skip_frames)
reader = self._wrap_reader(reader)
return reader
def get_mel_loader(self, audio_segment):
mel_dataset = self._get_mel_dataset(audio_segment)
return self._wrap_dataset(mel_dataset)
def get_face_loader(self, num_skip_frames=0):
face_dataset = self._get_face_dataset(num_skip_frames=num_skip_frames)
        return self._wrap_dataset(face_dataset)  # not cycled here, unlike the reader
    # Pad with silence so the frame count is a multiple of batch_size.
    def pad(self, audio_segment):
        num_frames = audio_segment.duration_seconds * self.fps
        pad = AudioSegment.silent(
            (self.batch_size - (num_frames % self.batch_size)) * (1000 / self.fps)
        )
        return audio_segment + pad
def _prepare_data(
self,
audio_segment,
video_start_offset_frame,
):
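        """Pad the audio and build face/mel batch loaders; the start offset
        wraps modulo the template length so reads cycle."""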
video_start_offset_frame = video_start_offset_frame % len(self.df)
padded = self.pad(audio_segment)
face_dataset = self._get_face_dataset(num_skip_frames=video_start_offset_frame)
mel_dataset = self._get_mel_dataset(audio_segment=padded)
n_frames = len(mel_dataset)
assert n_frames % self.batch_size == 0
face_loader = self._wrap_dataset(face_dataset)
mel_loader = self._wrap_dataset(mel_dataset)
return padded, face_loader, mel_loader
def gen_infer(
self,
audio_segment,
video_start_offset_frame,
):
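        """Yield (inferred_frame, audio_chunk) pairs, one chunk per frame."""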
padded, face_loader, mel_loader = self._prepare_data(
audio_segment=audio_segment,
video_start_offset_frame=video_start_offset_frame,
)
for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
inferred = self.inference_func(self.model, v, self.model.device)
for j, it in enumerate(inferred):
chunk_pivot = i * self.unit * self.batch_size + j * self.unit
chunk = padded[chunk_pivot : chunk_pivot + self.unit]
yield it, chunk
def gen_infer_batch(
self,
audio_segment,
video_start_offset_frame,
):
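        """Yield (inferred_batch, audio_chunk) pairs, one chunk per batch."""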
padded, face_loader, mel_loader = self._prepare_data(
audio_segment=audio_segment,
video_start_offset_frame=video_start_offset_frame,
)
for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
inferred = self.inference_func(self.model, v, self.model.device)
yield inferred, padded[
i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
]
def gen_infer_batch_future(
self,
pool,
audio_segment,
video_start_offset_frame,
):
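        """Submit every batch to the pool up front, then yield
        (future, audio_chunk) pairs so inference overlaps with consumption."""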
padded, face_loader, mel_loader = self._prepare_data(
audio_segment=audio_segment,
video_start_offset_frame=video_start_offset_frame,
)
futures = []
for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
futures.append(
pool.submit(self.inference_func, self.model, v, self.model.device)
)
for i, future in enumerate(futures):
yield future, padded[
i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
]
def gen_infer_concurrent(
self,
pool,
audio_segment,
video_start_offset_frame,
):
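        """Unpack batch futures back into per-frame (inferred, chunk) pairs."""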
for future, chunk in self.gen_infer_batch_future(
pool, audio_segment, video_start_offset_frame
):
for i, inferred in enumerate(future.result()):
yield inferred, chunk[i * self.unit : (i + 1) * self.unit]
def compose(
self,
idx,
frame,
output,
):
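        """Key the inferred output and blend it over the template frame."""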
head_box_idx = idx % len(self.df)
head_box = get_head_box(
self.df,
move=self.move,
head_box_idx=head_box_idx,
)
alpha2 = self.keying_func(output, head_box_idx, head_box)
frame = self.compose_func(alpha2, frame[:, :, :4], head_box_idx)
return frame
def gen_frames(
self,
audio_segment,
video_start_offset_frame,
reader=None,
):
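        """Yield (composed_frame, audio_chunk) pairs for the given audio."""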
reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
gen_infer = self.gen_infer(audio_segment, video_start_offset_frame)
for idx, ((o, a), f) in enumerate(
zip(gen_infer, reader), video_start_offset_frame
):
composed = self.compose(idx, f, o)
yield composed, a
def gen_frames_concurrent(
self,
pool,
audio_segment,
video_start_offset_frame,
reader=None,
):
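        """Like gen_frames, but inference runs concurrently in the pool."""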
reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
gen_infer = self.gen_infer_concurrent(
pool,
audio_segment,
video_start_offset_frame,
)
for idx, ((o, a), f) in enumerate(
zip(gen_infer, reader), video_start_offset_frame
):
yield self.compose(idx, f, o), a
class AsyncTemplate(Template):
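    """Async variant of Template for the remote model: datasets are built in
    an executor and inference requests run as asyncio tasks."""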
async def agen_infer_batch_future(
self,
pool,
audio_segment,
video_start_offset_frame,
):
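        """Async version of gen_infer_batch_future; remote model only."""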
assert self.remote
padded, face_loader, mel_loader = await self._aprepare_data(
pool,
audio_segment=audio_segment,
video_start_offset_frame=video_start_offset_frame,
)
futures = []
async for i, v in asyncstdlib.enumerate(
adictzip(aiter(mel_loader), aiter(face_loader))
):
futures.append(
asyncio.create_task(
ainference_model_remote(pool, self.model, v, self.model.device)
)
)
for i, future in enumerate(futures):
yield future, padded[
i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
]
async def _awrap_dataset(self, dataset):
dataloader = AsyncProcessPoolBatchIterator(
dataset=dataset,
batch_size=self.batch_size,
)
return dataloader
async def _aprepare_data(
self,
pool,
audio_segment,
video_start_offset_frame,
):
video_start_offset_frame = video_start_offset_frame % len(self.df)
padded = self.pad(audio_segment)
loop = asyncio.get_running_loop()
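        # Dataset construction is blocking work, so build both in the executor.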
face_dataset, mel_dataset = await asyncio.gather(
loop.run_in_executor(
pool, self._get_face_dataset, video_start_offset_frame
),
loop.run_in_executor(pool, self._get_mel_dataset, padded),
)
n_frames = len(mel_dataset)
assert n_frames % self.batch_size == 0
face_loader = await self._awrap_dataset(face_dataset)
mel_loader = await self._awrap_dataset(mel_dataset)
return padded, face_loader, mel_loader
def _aget_reader(self, num_skip_frames):
assert self.template_frames_path.exists()
return get_image_folder_async_process_reader(
data_path=self.template_frames_path,
num_skip_frames=num_skip_frames,
preload=self.batch_size,
)
def _awrap_reader(self, reader):
reader = acycle(reader)
return reader
def aget_reader(self, num_skip_frames=0):
reader = self._aget_reader(num_skip_frames=num_skip_frames)
reader = self._awrap_reader(reader)
return reader
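
# Minimal usage sketch (not from the original file; `model` and the paths
# below are placeholder assumptions that depend on how the surrounding
# project wires things up):
#
#   template = Template(
#       config_path="config.yaml",
#       model=model,
#       template_video_path="template.mp4",
#   )
#   audio = AudioSegment.from_file("speech.wav")
#   for frame, chunk in template.gen_frames(audio, video_start_offset_frame=0):
#       ...  # encode `frame` and mux `chunk` into the output stream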