import asyncio
import json
from pathlib import Path

import asyncstdlib
import numpy as np
import pandas as pd
from pydub import AudioSegment

from stf_alternative.compose import get_compose_func_without_keying, get_keying_func
from stf_alternative.dataset import LipGanAudio, LipGanImage, LipGanRemoteImage
from stf_alternative.inference import (
    adictzip,
    ainference_model_remote,
    audio_encode,
    dictzip,
    get_head_box,
    inference_model,
    inference_model_remote,
)
from stf_alternative.preprocess_dir.utils import face_finder as ff
from stf_alternative.readers import (
    AsyncProcessPoolBatchIterator,
    ProcessPoolBatchIterator,
    get_image_folder_async_process_reader,
    get_image_folder_process_reader,
)
from stf_alternative.util import (
    acycle,
    get_crop_mp4_dir,
    get_frame_dir,
    get_preprocess_dir,
    icycle,
    read_config,
)


def calc_audio_std(audio_segment):
    # Standard deviation of the samples, normalized by the max value of the
    # segment's sample width so that differently encoded segments are comparable.
    sample = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
    max_value = np.iinfo(
        np.int8
        if audio_segment.sample_width == 1
        else np.int16
        if audio_segment.sample_width == 2
        else np.int32
    ).max
    return sample.std() / max_value, len(sample)


class RunningAudioNormalizer:
    """Rescales incoming audio toward the loudness of a reference segment,
    tracking an exponentially decayed running estimate of the input's std."""

    def __init__(self, ref_audio_segment, decay_rate=0.01):
        self.ref_std, _ = calc_audio_std(ref_audio_segment)
        self.running_var = np.float64(0)
        self.running_cnt = 0
        self.decay_rate = decay_rate

    def __call__(self, audio_segment):
        std, cnt = calc_audio_std(audio_segment)
        self.running_var = (self.running_var + (std**2) * cnt) * (1 - self.decay_rate)
        self.running_cnt = (self.running_cnt + cnt) * (1 - self.decay_rate)
        return audio_segment._spawn(
            (np.asarray(audio_segment.get_array_of_samples()) / self.std * self.ref_std)
            .astype(np.int16)  # output is written back as 16-bit samples
            .tobytes()
        )

    @property
    def std(self):
        return np.sqrt(self.running_var / self.running_cnt)


def get_video_metadata(preprocess_dir):
    json_path = preprocess_dir / "metadata.json"
    with open(json_path, "r") as f:
        return json.load(f)
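# Usage sketch for RunningAudioNormalizer (illustrative only; the file path and
# the `incoming_chunks` source below are hypothetical, not part of this module):
#
#     ref = AudioSegment.from_file("reference.wav")
#     normalizer = RunningAudioNormalizer(ref, decay_rate=0.01)
#     for chunk in incoming_chunks:      # e.g. short AudioSegment slices
#         chunk = normalizer(chunk)      # rescaled toward the reference loudness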
class Template:
    def __init__(
        self,
        config_path,
        model,
        template_video_path,
        wav_std=False,
        ref_wav=None,
        verbose=False,
    ):
        self.config = read_config(config_path)
        self.model = model
        self.template_video_path = Path(template_video_path)
        self.preprocess_dir = Path(
            get_preprocess_dir(model.work_root_path, model.args.name)
        )
        self.crop_mp4_dir = Path(
            get_crop_mp4_dir(self.preprocess_dir, template_video_path)
        )
        self.dataset_dir = self.crop_mp4_dir / f"{Path(template_video_path).stem}_000"
        self.template_frames_path = Path(
            get_frame_dir(self.preprocess_dir, template_video_path, ratio=1.0)
        )
        self.verbose = verbose
        self.remote = self.model.args.model_type == "remote"
        self.audio_normalizer = (
            RunningAudioNormalizer(ref_wav) if wav_std else lambda x: x
        )
        self.df = pd.read_pickle(self.dataset_dir / "df_fan.pickle")
        metadata = get_video_metadata(self.preprocess_dir)
        self.fps = metadata["fps"]
        self.width, self.height = metadata["width"], metadata["height"]
        self.keying_func = get_keying_func(self)
        self.compose_func = get_compose_func_without_keying(self, ratio=1.0)
        self.move = "move" in self.config.keys() and self.config.move
        self.inference_func = inference_model_remote if self.remote else inference_model
        self.batch_size = self.model.args.batch_size
        self.unit = 1000 / self.fps  # milliseconds of audio per video frame

    def _get_reader(self, num_skip_frames):
        assert self.template_frames_path.exists()
        return get_image_folder_process_reader(
            data_path=self.template_frames_path,
            num_skip_frames=num_skip_frames,
            preload=self.batch_size,
        )

    def _get_local_face_dataset(self, num_skip_frames):
        return LipGanImage(
            args=self.model.args,
            path=self.dataset_dir,
            num_skip_frames=num_skip_frames,
        )

    def _get_remote_face_dataset(self, num_skip_frames):
        return LipGanRemoteImage(
            args=self.model.args,
            path=self.dataset_dir,
            num_skip_frames=num_skip_frames,
        )

    def _get_mel_dataset(self, audio_segment):
        # The audio was padded, so image_count divides evenly by batch_size.
        image_count = round(audio_segment.duration_seconds * self.fps)
        ids = list(range(image_count))
        mel = audio_encode(
            model=self.model,
            audio_segment=audio_segment,
            device=self.model.device,
        )
        return LipGanAudio(
            args=self.model.args,
            id_list=ids,
            mel=mel,
            fps=self.fps,
        )

    def _get_face_dataset(self, num_skip_frames):
        if self.remote:
            return self._get_remote_face_dataset(num_skip_frames=num_skip_frames)
        else:
            return self._get_local_face_dataset(num_skip_frames=num_skip_frames)

    def _wrap_reader(self, reader):
        reader = icycle(reader)
        return reader

    def _wrap_dataset(self, dataset):
        dataloader = ProcessPoolBatchIterator(
            dataset=dataset,
            batch_size=self.batch_size,
        )
        return dataloader

    def get_reader(self, num_skip_frames=0):
        reader = self._get_reader(num_skip_frames=num_skip_frames)
        reader = self._wrap_reader(reader)
        return reader

    def get_mel_loader(self, audio_segment):
        mel_dataset = self._get_mel_dataset(audio_segment)
        return self._wrap_dataset(mel_dataset)

    def get_face_loader(self, num_skip_frames=0):
        face_dataset = self._get_face_dataset(num_skip_frames=num_skip_frames)
        return self._wrap_dataset(face_dataset)  # need cycle

    # Pad the audio with silence so the frame count divides evenly by batch_size.
    def pad(self, audio_segment):
        num_frames = audio_segment.duration_seconds * self.fps
        pad = AudioSegment.silent(
            (self.batch_size - (num_frames % self.batch_size)) * (1000 / self.fps)
        )
        return audio_segment + pad

    def _prepare_data(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        video_start_offset_frame = video_start_offset_frame % len(self.df)
        padded = self.pad(audio_segment)
        face_dataset = self._get_face_dataset(num_skip_frames=video_start_offset_frame)
        mel_dataset = self._get_mel_dataset(audio_segment=padded)
        n_frames = len(mel_dataset)
        assert n_frames % self.batch_size == 0
        face_loader = self._wrap_dataset(face_dataset)
        mel_loader = self._wrap_dataset(mel_dataset)
        return padded, face_loader, mel_loader

    def gen_infer(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            inferred = self.inference_func(self.model, v, self.model.device)
            for j, it in enumerate(inferred):
                # Audio chunk (in ms) corresponding to this single frame.
                chunk_pivot = i * self.unit * self.batch_size + j * self.unit
                chunk = padded[chunk_pivot : chunk_pivot + self.unit]
                yield it, chunk

    def gen_infer_batch(
        self,
        audio_segment,
        video_start_offset_frame,
    ):
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            inferred = self.inference_func(self.model, v, self.model.device)
            yield inferred, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    def gen_infer_batch_future(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        padded, face_loader, mel_loader = self._prepare_data(
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        futures = []
        for i, v in enumerate(dictzip(iter(mel_loader), iter(face_loader))):
            futures.append(
                pool.submit(self.inference_func, self.model, v, self.model.device)
            )
        for i, future in enumerate(futures):
            yield future, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    def gen_infer_concurrent(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        for future, chunk in self.gen_infer_batch_future(
            pool, audio_segment, video_start_offset_frame
        ):
            for i, inferred in enumerate(future.result()):
                yield inferred, chunk[i * self.unit : (i + 1) * self.unit]

    def compose(
        self,
        idx,
        frame,
        output,
    ):
        head_box_idx = idx % len(self.df)
        head_box = get_head_box(
            self.df,
            move=self.move,
            head_box_idx=head_box_idx,
        )
        alpha2 = self.keying_func(output, head_box_idx, head_box)
        frame = self.compose_func(alpha2, frame[:, :, :4], head_box_idx)
        return frame

    def gen_frames(
        self,
        audio_segment,
        video_start_offset_frame,
        reader=None,
    ):
        reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
        gen_infer = self.gen_infer(audio_segment, video_start_offset_frame)
        for idx, ((o, a), f) in enumerate(
            zip(gen_infer, reader), video_start_offset_frame
        ):
            composed = self.compose(idx, f, o)
            yield composed, a

    def gen_frames_concurrent(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
        reader=None,
    ):
        reader = reader or self.get_reader(num_skip_frames=video_start_offset_frame)
        gen_infer = self.gen_infer_concurrent(
            pool,
            audio_segment,
            video_start_offset_frame,
        )
        for idx, ((o, a), f) in enumerate(
            zip(gen_infer, reader), video_start_offset_frame
        ):
            yield self.compose(idx, f, o), a
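# Sketch: driving the concurrent path. Any pool whose submit() returns a
# concurrent.futures-style Future works, since gen_infer_batch_future only
# calls pool.submit(...) and future.result(). The `template` instance, the
# audio path, and the consumer body below are assumed, not part of this module:
#
#     from concurrent.futures import ThreadPoolExecutor
#
#     audio = AudioSegment.from_file("speech.wav")  # hypothetical path
#     with ThreadPoolExecutor(max_workers=2) as pool:
#         for frame, audio_chunk in template.gen_frames_concurrent(pool, audio, 0):
#             ...  # e.g. mux frame + audio_chunk into an output stream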
class AsyncTemplate(Template):
    async def agen_infer_batch_future(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        assert self.remote
        padded, face_loader, mel_loader = await self._aprepare_data(
            pool,
            audio_segment=audio_segment,
            video_start_offset_frame=video_start_offset_frame,
        )
        futures = []
        # aiter() is the builtin async-iterator helper (Python 3.10+).
        async for i, v in asyncstdlib.enumerate(
            adictzip(aiter(mel_loader), aiter(face_loader))
        ):
            futures.append(
                asyncio.create_task(
                    ainference_model_remote(pool, self.model, v, self.model.device)
                )
            )
        for i, future in enumerate(futures):
            yield future, padded[
                i * self.unit * self.batch_size : (i + 1) * self.unit * self.batch_size
            ]

    async def _awrap_dataset(self, dataset):
        dataloader = AsyncProcessPoolBatchIterator(
            dataset=dataset,
            batch_size=self.batch_size,
        )
        return dataloader

    async def _aprepare_data(
        self,
        pool,
        audio_segment,
        video_start_offset_frame,
    ):
        video_start_offset_frame = video_start_offset_frame % len(self.df)
        padded = self.pad(audio_segment)
        loop = asyncio.get_running_loop()
        # Dataset construction is blocking, so build both off the event loop.
        face_dataset, mel_dataset = await asyncio.gather(
            loop.run_in_executor(
                pool, self._get_face_dataset, video_start_offset_frame
            ),
            loop.run_in_executor(pool, self._get_mel_dataset, padded),
        )
        n_frames = len(mel_dataset)
        assert n_frames % self.batch_size == 0
        face_loader = await self._awrap_dataset(face_dataset)
        mel_loader = await self._awrap_dataset(mel_dataset)
        return padded, face_loader, mel_loader

    def _aget_reader(self, num_skip_frames):
        assert self.template_frames_path.exists()
        return get_image_folder_async_process_reader(
            data_path=self.template_frames_path,
            num_skip_frames=num_skip_frames,
            preload=self.batch_size,
        )

    def _awrap_reader(self, reader):
        reader = acycle(reader)
        return reader

    def aget_reader(self, num_skip_frames=0):
        reader = self._aget_reader(num_skip_frames=num_skip_frames)
        reader = self._awrap_reader(reader)
        return reader
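# Sketch: consuming AsyncTemplate from a coroutine. The `template` instance and
# the consumer body are hypothetical; agen_infer_batch_future yields
# (task, audio_chunk) pairs, where each task resolves to a batch of inferred
# frames. A ThreadPoolExecutor is assumed for the executor argument:
#
#     from concurrent.futures import ThreadPoolExecutor
#
#     async def consume(template, audio):
#         with ThreadPoolExecutor() as pool:
#             async for task, chunk in template.agen_infer_batch_future(
#                 pool, audio, 0
#             ):
#                 batch = await task
#                 ...  # compose/emit each frame in `batch` against `chunk`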