import asyncio
import json
from pathlib import Path

import asyncstdlib
import numpy as np
import pandas as pd
from pydub import AudioSegment

from stf_alternative.compose import get_compose_func_without_keying, get_keying_func
from stf_alternative.dataset import LipGanAudio, LipGanImage, LipGanRemoteImage
from stf_alternative.inference import (
    adictzip,
    ainference_model_remote,
    audio_encode,
    dictzip,
    get_head_box,
    inference_model,
    inference_model_remote,
)
from stf_alternative.preprocess_dir.utils import face_finder as ff
from stf_alternative.readers import (
    AsyncProcessPoolBatchIterator,
    ProcessPoolBatchIterator,
    get_image_folder_async_process_reader,
    get_image_folder_process_reader,
)
from stf_alternative.util import (
    acycle,
    get_crop_mp4_dir,
    get_frame_dir,
    get_preprocess_dir,
    icycle,
    read_config,
)


def calc_audio_std(audio_segment):
    """Return (relative std, sample count) for an AudioSegment.

    The standard deviation is normalized by the dtype's maximum value so that
    segments of different sample widths are comparable.
    """
    # get_array_of_samples() already carries the correct integer typecode, so
    # let numpy infer the dtype instead of hardcoding int16.
    samples = np.array(audio_segment.get_array_of_samples())
    # Pick the integer type matching the sample width:
    # 1 byte -> int8, 2 bytes -> int16, otherwise int32.
    dtype = (
        np.int8
        if audio_segment.sample_width == 1
        else np.int16
        if audio_segment.sample_width == 2
        else np.int32
    )
    max_value = np.iinfo(dtype).max
    return samples.std() / max_value, len(samples)
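

# A minimal smoke test for calc_audio_std, left as a sketch and not executed
# at import time. It assumes pydub's default 16-bit mono output from
# AudioSegment.silent(); `_demo_calc_audio_std` is illustrative, not part of
# the original API.
def _demo_calc_audio_std():
    silence = AudioSegment.silent(duration=1000, frame_rate=16000)
    rel_std, n_samples = calc_audio_std(silence)
    # Silence has zero deviation; one second at 16 kHz yields 16000 samples.
    assert rel_std == 0.0
    assert n_samples == 16000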


class RunningAudioNormalizer:
    """Rescales incoming audio toward the loudness of a reference segment.

    Keeps an exponentially decayed running variance of everything seen so far
    and scales each segment by ref_std / running_std.
    """

    def __init__(self, ref_audio_segment, decay_rate=0.01):
        self.ref_std, _ = calc_audio_std(ref_audio_segment)
        self.running_var = np.float64(0)
        self.running_cnt = 0
        self.decay_rate = decay_rate

    def __call__(self, audio_segment):
        std, cnt = calc_audio_std(audio_segment)
        # Accumulate sum-of-squares and sample count, decaying both so that
        # recent segments dominate the estimate.
        self.running_var = (self.running_var + (std**2) * cnt) * (1 - self.decay_rate)
        self.running_cnt = (self.running_cnt + cnt) * (1 - self.decay_rate)
        # Rescale the raw samples; std is a method, so it must be called here.
        # int16 matches the 16-bit audio this pipeline works with.
        samples = np.array(audio_segment.get_array_of_samples())
        return audio_segment._spawn(
            (samples / self.std() * self.ref_std).astype(np.int16).tobytes()
        )

    def std(self):
        return np.sqrt(self.running_var / self.running_cnt)
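

# Hedged usage sketch: with decay_rate=0.01 each call multiplies the
# accumulated variance and count by 0.99, so std() tracks the loudness of
# recent chunks:
#
#     running_var_t = (running_var_{t-1} + std_t**2 * cnt_t) * 0.99
#     running_cnt_t = (running_cnt_{t-1} + cnt_t) * 0.99
#     std()         = sqrt(running_var_t / running_cnt_t)
#
# `_demo_running_normalizer` and its sine test signals are illustrative only,
# not part of the original API; it is defined but never executed at import.
def _demo_running_normalizer():
    from pydub.generators import Sine

    ref = Sine(440).to_audio_segment(duration=1000)        # reference loudness
    quiet = Sine(440).to_audio_segment(duration=100) - 20  # 20 dB quieter
    normalizer = RunningAudioNormalizer(ref, decay_rate=0.01)
    boosted = normalizer(quiet)
    # The normalized chunk should land near the reference loudness.
    ref_std, _ = calc_audio_std(ref)
    out_std, _ = calc_audio_std(boosted)
    assert abs(out_std - ref_std) / ref_std < 0.05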


def get_video_metadata(preprocess_dir):
    json_path = preprocess_dir / "metadata.json"
    with open(json_path, "r") as f:
        return json.load(f)


class Template:
    def __init__(
        self,
        config_path,
        model,
        template_video_path,
        wav_std=False,
        ref_wav=None,
        verbose=False,
    ):
        self.config = read_config(config_path)
        self.model = model
        self.template_video_path = Path(template_video_path)
        self.preprocess_dir = Path(
            get_preprocess_dir(model.work_root_path, model.args.name)
        )
        self.crop_mp4_dir = Path(
            get_crop_mp4_dir(self.preprocess_dir, template_video_path)
        )
        self.dataset_dir = self.crop_mp4_dir / f"{Path(template_video_path).stem}_000"
        self.template_frames_path = Path(
            get_frame_dir(self.preprocess_dir, template_video_path, ratio=1.0)
        )
        self.verbose = verbose
        self.remote = self.model.args.model_type == "remote"
        # Normalize loudness only when requested; otherwise pass audio through.
        self.audio_normalizer = (
            RunningAudioNormalizer(ref_wav) if wav_std else lambda x: x
        )
        self.df = pd.read_pickle(self.dataset_dir / "df_fan.pickle")
        metadata = get_video_metadata(self.preprocess_dir)
        self.fps = metadata["fps"]
        self.width, self.height = metadata["width"], metadata["height"]
        self.keying_func = get_keying_func(self)
        self.compose_func = get_compose_func_without_keying(self, ratio=1.0)
        self.move = "move" in self.config and self.config.move
        self.inference_func = inference_model_remote if self.remote else inference_model
        self.batch_size = self.model.args.batch_size
        self.unit = 1000 / self.fps  # duration of one video frame in ms

    def _get_reader(self, num_skip_frames):
        assert self.template_frames_path.exists()
        return get_image_folder_process_reader(
            data_path=self.template_frames_path,
            num_skip_frames=num_skip_frames,
            preload=self.batch_size,
        )

    def _get_local_face_dataset(self, num_skip_frames):
        return LipGanImage(
            args=self.model.args,
            path=self.dataset_dir,
            num_skip_frames=num_skip_frames,
        )

    def _get_remote_face_dataset(self, num_skip_frames):
        return LipGanRemoteImage(
            args=self.model.args,
            path=self.dataset_dir,
            num_skip_frames=num_skip_frames,
        )

    def _get_mel_dataset(self, audio_segment):
        image_count = round(
            audio_segment.duration_seconds * self.fps
        )  # the audio is padded, so this divides evenly by batch_size
        ids = list(range(image_count))
        mel = audio_encode(
            model=self.model,
            audio_segment=audio_segment,
            device=self.model.device,
        )
        return LipGanAudio(
            args=self.model.args,
            id_list=ids,
            mel=mel,
            fps=self.fps,
        )

    def _get_face_dataset(self, num_skip_frames):
        if self.remote:
            return self._get_remote_face_dataset(num_skip_frames=num_skip_frames)
        else:
            return self._get_local_face_dataset(num_skip_frames=num_skip_frames)

    def _wrap_reader(self, reader):
        return icycle(reader)

    def _wrap_dataset(self, dataset):
        return ProcessPoolBatchIterator(
            dataset=dataset,
            batch_size=self.batch_size,
        )

    def get_reader(self, num_skip_frames=0):
        reader = self._get_reader(num_skip_frames=num_skip_frames)
        reader = self._wrap_reader(reader)
        return reader

    def get_mel_loader(self, audio_segment):
        mel_dataset = self._get_mel_dataset(audio_segment)
        return self._wrap_dataset(mel_dataset)

    def get_face_loader(self, num_skip_frames=0):
        face_dataset = self._get_face_dataset(num_skip_frames=num_skip_frames)
        return self._wrap_dataset(face_dataset)  # need cycle

    def pad(self, audio_segment):
        """Append silence so the frame count becomes a multiple of batch_size."""
        num_frames = audio_segment.duration_seconds * self.fps
        pad = AudioSegment.silent(
            (self.batch_size - (num_frames % self.batch_size)) * (1000 / self.fps)
        )
        return audio_segment + pad
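
    # Worked example of the padding arithmetic, with made-up numbers: at
    # fps=30 and batch_size=8, a 1.0 s segment spans 30 frames, so pad()
    # appends 2 frames (~66.7 ms) of silence to reach 32 frames. The values
    # below are illustrative, not from the original code.
    #
    #     num_frames = 1.0 * 30            # 30.0
    #     missing    = 8 - (30.0 % 8)      # 2.0 frames
    #     pad_ms     = 2.0 * (1000 / 30)   # ~66.7 ms of silence
    #     total      = 30 + 2              # 32 == 4 batches of 8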

    def _prepare_data(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        # Wrap the start offset so it always lands inside the template.
        video_start_offset_frame = video_start_offset_frame % len(self.df)
        padded = self.pad(audio_segment)
        face_dataset = self._get_face_dataset(num_skip_frames=video_start_offset_frame)
        mel_dataset = self._get_mel_dataset(audio_segment=padded)
        n_frames = len(mel_dataset)
        assert n_frames % self.batch_size == 0
        face_loader = self._wrap_dataset(face_dataset)
        mel_loader = self._wrap_dataset(mel_dataset)
        return padded, face_loader, mel_loader

    def gen_infer(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            inferred = self.inference_func(self.model, v, self.model.device)
            for j, it in enumerate(inferred):
                # Slice out the one-frame chunk of audio (in ms) that matches
                # this inferred frame.
                chunk_pivot = i * self.unit * self.batch_size + j * self.unit
                chunk = padded[chunk_pivot : chunk_pivot + self.unit]
                yield it, chunk
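
    # Hedged sketch of consuming gen_infer: each step yields one inferred
    # mouth frame plus the matching slice of audio. `template`, `audio`, and
    # `write_frame` are hypothetical stand-ins, not part of this module.
    #
    #     for frame, audio_chunk in template.gen_infer(audio, 0):
    #         write_frame(frame, audio_chunk)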

    def gen_infer_batch(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            inferred = self.inference_func(self.model, v, self.model.device)
            yield inferred, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    def gen_infer_batch_future(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        # Submit every batch to the pool first, then yield the futures in
        # order, each paired with its slice of the padded audio.
        futures = []
        for v in dictzip(iter(mel_loader), iter(face_loader)):
            futures.append(
                pool.submit(self.inference_func, self.model, v, self.model.device)
            )
        for i, future in enumerate(futures):
            yield future, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    def gen_infer_concurrent(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        for future, chunk in self.gen_infer_batch_future(
            pool, audio_segment, video_start_offset_frame
        ):
            for i, inferred in enumerate(future.result()):
                yield inferred, chunk[i * self.unit : (i + 1) * self.unit]

    def compose(
        self,
        idx,
        frame,
        output,
    ):
        # Wrap the frame index so the template loops.
        head_box_idx = idx % len(self.df)
        head_box = get_head_box(
            self.df,
            move=self.move,
            head_box_idx=head_box_idx,
        )
        alpha2 = self.keying_func(output, head_box_idx, head_box)
        # Keep at most four channels (RGBA) of the template frame.
        frame = self.compose_func(alpha2, frame[:, :, :4], head_box_idx)
        return frame

    def gen_frames(
        self,
        audio_segment,
        video_start_offset_frame,
        reader=None,
    ):
        reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
        gen_infer = self.gen_infer(audio_segment, video_start_offset_frame)
        for idx, ((o, a), f) in enumerate(
            zip(gen_infer, reader), video_start_offset_frame
        ):
            composed = self.compose(idx, f, o)
            yield composed, a

    def gen_frames_concurrent(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
        reader=None,
    ):
        reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
        gen_infer = self.gen_infer_concurrent(
            pool,
            audio_segment,
            video_start_offset_frame,
        )
        for idx, ((o, a), f) in enumerate(
            zip(gen_infer, reader), video_start_offset_frame
        ):
            yield self.compose(idx, f, o), a
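

# Hedged end-to-end sketch of driving Template.gen_frames_concurrent with a
# thread pool. `make_model`, the file paths, and the writer loop are
# hypothetical; only Template and its generators come from this module.
#
#     from concurrent.futures import ThreadPoolExecutor
#     from pydub import AudioSegment
#
#     model = make_model(...)                         # hypothetical factory
#     template = Template("config.json", model, "template.mp4")
#     audio = AudioSegment.from_file("speech.wav")
#     with ThreadPoolExecutor(max_workers=2) as pool:
#         for frame, chunk in template.gen_frames_concurrent(pool, audio, 0):
#             ...  # encode frame + chunk into the output stream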


class AsyncTemplate(Template):
    async def agen_infer_batch_future(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        # Only the remote model path supports async inference.
        assert self.remote
        padded, face_loader, mel_loader = await self._aprepare_data(
            pool,
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        # Schedule inference for every batch up front, then yield the tasks
        # in order, each paired with its slice of the padded audio.
        futures = []
        async for i, v in asyncstdlib.enumerate(
            adictzip(aiter(mel_loader), aiter(face_loader))
        ):
            futures.append(
                asyncio.create_task(
                    ainference_model_remote(pool, self.model, v, self.model.device)
                )
            )
        for i, future in enumerate(futures):
            yield future, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    async def _awrap_dataset(self, dataset):
        return AsyncProcessPoolBatchIterator(
            dataset=dataset,
            batch_size=self.batch_size,
        )

    async def _aprepare_data(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        video_start_offset_frame = video_start_offset_frame % len(self.df)
        padded = self.pad(audio_segment)
        loop = asyncio.get_running_loop()
        # Build the face and mel datasets off the event loop, in parallel.
        face_dataset, mel_dataset = await asyncio.gather(
            loop.run_in_executor(
                pool, self._get_face_dataset, video_start_offset_frame
            ),
            loop.run_in_executor(pool, self._get_mel_dataset, padded),
        )
        n_frames = len(mel_dataset)
        assert n_frames % self.batch_size == 0
        face_loader = await self._awrap_dataset(face_dataset)
        mel_loader = await self._awrap_dataset(mel_dataset)
        return padded, face_loader, mel_loader

    def _aget_reader(self, num_skip_frames):
        assert self.template_frames_path.exists()
        return get_image_folder_async_process_reader(
            data_path=self.template_frames_path,
            num_skip_frames=num_skip_frames,
            preload=self.batch_size,
        )

    def _awrap_reader(self, reader):
        return acycle(reader)

    def aget_reader(self, num_skip_frames=0):
        reader = self._aget_reader(num_skip_frames=num_skip_frames)
        reader = self._awrap_reader(reader)
        return reader
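

# Hedged sketch of the async path: schedule all remote batches with
# agen_infer_batch_future, then await each task in order. `template`, `pool`,
# and `audio` are hypothetical stand-ins; composing the frames would follow
# the same steps as Template.gen_frames.
#
#     async def run(template, pool, audio):
#         async for task, chunk in template.agen_infer_batch_future(
#             pool, audio, video_start_offset_frame=0
#         ):
#             batch = await task
#             ...  # compose and emit each frame in `batch` with its audio
#
#     # asyncio.run(run(AsyncTemplate(...), pool, audio))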