"""Helpers for silence-based audio splitting, audio feature extraction, dict-zipping
utilities, and batched model inference (local, remote, and async remote)."""

import asyncio
from itertools import chain

import numpy as np
import torch
from pydub import AudioSegment, silence


def check_split_lengths(silent_ranges, len_audio):
    """Return True if every chunk implied by `silent_ranges` is at most 70 seconds."""
    for idx, (start, _) in enumerate(silent_ranges):
        if idx < len(silent_ranges) - 1:
            if silent_ranges[idx + 1][0] - start > 70000:
                return False
        else:
            if len_audio - start > 70000:
                return False
    return True


def load_and_split_audio_by_silence(
    audio_segment,
    silence_thresh: int = -75,
    min_silence_len: int = 500,
    min_chunk_length_ms: int = 40,
    seek_step: int = 100,
    verbose: bool = False,
):
    """Split mono 16 kHz audio at detected silences so no chunk exceeds 70 seconds.

    The silence threshold and minimum silence length are progressively relaxed
    until a valid split is found. Chunk boundaries are aligned to 320-sample
    (20 ms at 16 kHz) frames.
    """
    audio_segment = audio_segment.set_channels(1)
    audio_segment = audio_segment.set_frame_rate(16000)

    # Relax the detection parameters until every chunk fits within 70 seconds.
    for st in range(silence_thresh, -50, 5):
        for msl in range(min_silence_len, 0, -100):
            silent_ranges = silence.detect_silence(
                audio_segment, msl, st, seek_step=seek_step
            )
            length_ok = check_split_lengths(silent_ranges, len(audio_segment))
            if length_ok:
                break
        if len(silent_ranges) > 0 and length_ok:
            break

    # Short audio with no detectable silence is returned as a single chunk.
    if (
        len(silent_ranges) == 0
        and len(audio_segment) < 70000
        and len(audio_segment) >= 40
    ):
        return [audio_segment]

    assert (
        length_ok and len(silent_ranges) > 0
    ), "Each sentence must be within 70 seconds, including silence"

    audio_chunks = []
    prev_end = 0
    for idx, (start, end) in enumerate(silent_ranges):
        if idx < len(silent_ranges) - 1:
            chunk_length = silent_ranges[idx + 1][0] - prev_end
            silence_length = end - prev_end
            chunk_length_samples = chunk_length * 16  # ms -> samples (16000 samples/sec)
            # Align the chunk length to 320-sample (20 ms) frames;
            # the first chunk gets 80 extra samples.
            if idx == 0:
                target_length_samples = (chunk_length_samples // 320 + 1) * 320 + 80
            else:
                target_length_samples = (chunk_length_samples // 320 + 1) * 320
            target_length = target_length_samples // 16  # samples -> ms
            adjusted_end = prev_end + target_length
        else:
            silence_length = (
                silent_ranges[-1][1] - prev_end
                if silent_ranges[-1][1] != len(audio_segment)
                else 0
            )
            adjusted_end = len(audio_segment)

        # Silence longer than 300 ms at the head of the chunk is split off as its own
        # chunk, keeping roughly 300 ms of lead-in silence on the speech chunk.
        silence_length_split = max(0, silence_length - 300)  # ms
        if silence_length_split <= 0:
            silence_chunk = None
            chunk = audio_segment[
                (prev_end if idx == 0 else prev_end - 5) : adjusted_end
            ]
        else:
            silence_length_samples = silence_length_split * 16  # ms -> samples
            if idx == 0:
                target_length_samples = (silence_length_samples // 320 + 1) * 320 + 80
            else:
                target_length_samples = (silence_length_samples // 320 + 1) * 320
            silence_length_split = target_length_samples // 16  # samples -> ms
            silence_chunk = audio_segment[
                (prev_end if idx == 0 else prev_end - 5) : prev_end + silence_length_split
            ]
            chunk = audio_segment[prev_end + silence_length_split - 5 : adjusted_end]

        if len(chunk) >= min_chunk_length_ms:
            if silence_chunk is not None:
                audio_chunks.append(silence_chunk)
            audio_chunks.append(chunk)
        else:
            # Chunks shorter than the minimum are merged into the previous chunk.
            if audio_chunks:
                if silence_chunk is not None:
                    audio_chunks[-1] += silence_chunk
                audio_chunks[-1] += chunk
        prev_end = adjusted_end

    return audio_chunks


def process_audio_chunks(
    audio_processor, audio_encoder, audio_chunks: list[AudioSegment], device
):
    """Encode each audio chunk and return the list of feature tensors."""
    features_list = []
    for audio_chunk in audio_chunks:
        features = process_audio_chunk(
            audio_processor, audio_encoder, audio_chunk, device
        )
        features_list.append(features)
    return features_list


def process_audio_chunk(audio_processor, audio_encoder, audio_chunk, device):
    """Normalize integer PCM samples to [-1, 1] floats and run the audio encoder."""
    audio_data = np.array(audio_chunk.get_array_of_samples(), dtype=np.float32)
    audio_data /= np.iinfo(
        np.int8
        if audio_chunk.sample_width == 1
        else np.int16
        if audio_chunk.sample_width == 2
        else np.int32
    ).max
    input_values = audio_processor(
        audio_data, sampling_rate=16000, return_tensors="pt"
    ).to(device)["input_values"]
    with torch.no_grad():
        outputs = audio_encoder(input_values=input_values)
    return outputs.last_hidden_state[0]


def audio_encode(model, audio_segment, device):
    """Split audio on silences, encode each chunk, and concatenate the features."""
    audio_chunks = load_and_split_audio_by_silence(audio_segment)
    features_list = process_audio_chunks(
        model.audio_processor, model.audio_encoder, audio_chunks, device
    )
    concatenated_features = torch.cat(features_list, dim=0)
    return concatenated_features.detach().cpu().numpy()


def dictzip(*iterators):
    """Zip dict-yielding iterators, merging each row into a single dict.

    Stops when the shortest iterator is exhausted.
    """
    try:
        while True:
            yield dict(chain(*[next(iterator).items() for iterator in iterators]))
    except StopIteration:
        pass


async def adictzip(*aiterators):
    """Async counterpart of `dictzip` for dict-yielding async iterators."""
    try:
        while True:
            yield dict(
                chain(*[(await anext(aiterator)).items() for aiterator in aiterators])
            )
    except StopAsyncIteration:
        pass


def to_img(t):
    """Convert a [-1, 1] NCHW float tensor into uint8 NHWC images with channels swapped."""
    t = t.permute(0, 2, 3, 1)
    img = ((t / 2.0) + 0.5) * 255.0
    img = torch.clip(img, 0.0, 255.0).type(torch.uint8)
    img = img.cpu().numpy()
    img = img[:, :, :, [2, 1, 0]]  # swap channel order (RGB <-> BGR)
    return img


def inference_model(model, v, device, verbose=False):
    """Run the local model on a batch `v` and return one result dict per sample."""
    with torch.no_grad():
        mel, ips, mask, alpha = (
            v["mel"],
            v["ips"],
            v["mask"],
            v["img_gt_with_alpha"],
        )
        cpu_ips = ips
        cpu_alpha = alpha
        audio = mel.to(device)
        ips = ips.to(device).permute(0, 3, 1, 2)
        pred = model.model(ips, audio)
        gen_face = to_img(pred)
        return [
            {
                "pred": o,
                "mask": mask[j].numpy(),
                "ips": cpu_ips[j].numpy(),
                "img_gt_with_alpha": cpu_alpha[j].numpy(),
                "filename": v["filename"][j],
            }
            for j, o in enumerate(gen_face)
        ]


def inference_model_remote(model, v, device, verbose=False):
    """Run the remote model; on failure, return a placeholder list of None values."""
    ips, mel = v["ips"], v["mel"]
    try:
        pred = model.model(
            ips=ips,
            mel=mel,
        )
        return postprocess_result(pred, v)
    except Exception:
        return [None] * len(v["filename"])


def postprocess_result(pred, v):
    """Convert model output to NHWC arrays with swapped channels and package per sample."""
    pred = pred.cpu().numpy()
    pred = pred.transpose(0, 2, 3, 1)
    pred = pred[:, :, :, [2, 1, 0]]  # swap channel order (RGB <-> BGR)
    return [
        {
            "pred": o,
            "mask": v["mask"][j].numpy(),
            "img_gt_with_alpha": v["img_gt_with_alpha"][j].numpy(),
            "filename": v["filename"][j],
        }
        for j, o in enumerate(pred)
    ]


async def ainference_model_remote(pool, model, v, device, verbose=False):
    """Async variant of `inference_model_remote`; postprocessing runs in `pool`."""
    ips, mel = v["ips"], v["mel"]
    try:
        pred = await model.model(
            ips=ips,
            mel=mel,
        )
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(pool, postprocess_result, pred, v)
    except Exception:
        return [None] * len(v["filename"])


def get_head_box(df, move=False, head_box_idx=0, template_ratio=1.0):
    """Return the (x1, y1, x2, y2) head crop box from the dataframe."""
    # sz = df['cropped_size'].values[0]
    # Shrink the cropped_box by the ratio the original 4K template was scaled down by.
    if move:
        x1, y1, x2, y2 = np.array(df["cropped_box"][head_box_idx])
    else:
        x1, y1, x2, y2 = np.round(
            np.array(df["cropped_box"].values[0]) * template_ratio
        ).astype(np.int32)  # pixel coordinates can exceed 255, so use a wide int type
    return x1, y1, x2, y2