Spaces:
Sleeping
Sleeping
LivePortrait2
/
stf
/stf-api-alternative
/src
/stf_alternative
/.ipynb_checkpoints
/inference-checkpoint.py
import asyncio
import logging
from itertools import chain

import numpy as np
import torch
from pydub import AudioSegment, silence
def check_split_lengths(silent_ranges, len_audio, max_chunk_ms: int = 70000):
    """Return True if cutting at ``silent_ranges`` keeps every segment short enough.

    Segments approximate how ``load_and_split_audio_by_silence`` cuts audio:
    - The first segment starts at 0 ms, not at the first silence. The splitter
      starts its first chunk at ``prev_end = 0``.
    - Every later segment starts at its silence's start.
    - Each segment ends at the next silence's start; the last one ends at
      ``len_audio``.

    Args:
        silent_ranges: ``(start_ms, end_ms)`` pairs as returned by
            ``pydub.silence.detect_silence`` (presumably in ascending order).
        len_audio: Total audio length in ms.
        max_chunk_ms: Longest allowed segment in ms (inclusive).

    Returns:
        True if every segment is at most ``max_chunk_ms`` long. With no
        silences, the whole clip is a single segment.
    """
    if not silent_ranges:
        return len_audio <= max_chunk_ms
    # Cut points: 0, the start of every silence after the first, then the end.
    cut_points = [0] + [start for start, _end in silent_ranges[1:]] + [len_audio]
    return all(
        later - earlier <= max_chunk_ms
        for earlier, later in zip(cut_points, cut_points[1:])
    )
def _snap_ms_to_frames(length_ms, is_first):
    """Round ``length_ms`` up to the next 320-sample boundary at 16 kHz, in ms.

    The result is ``320 * k`` samples, or ``320 * k + 80`` samples when
    ``is_first`` is set. ``k`` is the smallest multiple strictly above
    ``length_ms * 16``.
    """
    samples = length_ms * 16  # ms -> samples (16000 samples/sec)
    target_samples = (samples // 320 + 1) * 320 + (80 if is_first else 0)
    return target_samples // 16  # samples -> ms


def load_and_split_audio_by_silence(
    audio_segment,
    silence_thresh: int = -75,
    min_silence_len: int = 500,
    min_chunk_length_ms: int = 40,
    seek_step: int = 100,
    verbose: bool = False,
):
    """Split an utterance into encoder-sized chunks at detected silences.

    The audio is downmixed to mono and resampled to 16 kHz. A grid search then
    runs ``silence.detect_silence`` over:
    - thresholds from ``silence_thresh`` upward in 5 dB steps (below -50);
    - minimum silence lengths from ``min_silence_len`` downward in 100 ms steps.
    The search stops at the first split that ``check_split_lengths`` accepts.

    Cut points are snapped with ``_snap_ms_to_frames``:
    - The first chunk gets 80 extra samples (5 ms).
    - Every later chunk starts 5 ms before the previous cut.
    As a result, non-final chunks span ``320 * k + 80`` samples. Merging short
    chunks and the final chunk may break this.
    NOTE(review): presumably this matches an encoder with a 320-sample hop and a
    400-sample window (wav2vec2-style) so concatenated features align with the
    audio. Confirm against the ``audio_encoder`` in use.

    A silence longer than 300 ms is split in two. The speech chunk keeps about
    300 ms of leading silence; the rest becomes its own chunk.

    Args:
        audio_segment: pydub ``AudioSegment`` to split.
        silence_thresh: Starting silence threshold in dBFS.
        min_silence_len: Starting minimum silence length in ms.
        min_chunk_length_ms: Chunks shorter than this are merged into the
            previous chunk. If there is no previous chunk, they are dropped.
        seek_step: Step size in ms passed to ``silence.detect_silence``.
        verbose: Currently unused; kept for interface compatibility.

    Returns:
        List of ``AudioSegment`` chunks in temporal order.

    Raises:
        AssertionError: If no split keeps every segment within 70 s.
    """
    audio_segment = audio_segment.set_channels(1)
    audio_segment = audio_segment.set_frame_rate(16000)

    # Bound up front so the checks below are well-defined even when the
    # search grid is empty (e.g. silence_thresh >= -50 or min_silence_len <= 0).
    silent_ranges = []
    length_ok = False
    for st in range(silence_thresh, -50, 5):
        for msl in range(min_silence_len, 0, -100):
            silent_ranges = silence.detect_silence(
                audio_segment, msl, st, seek_step=seek_step
            )
            length_ok = check_split_lengths(silent_ranges, len(audio_segment))
            if length_ok:
                break
        if len(silent_ranges) > 0 and length_ok:
            break

    # No silence found at all: a short-enough clip is encoded as one chunk.
    if (
        len(silent_ranges) == 0
        and len(audio_segment) < 70000
        and len(audio_segment) >= min_chunk_length_ms
    ):
        return [audio_segment]
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    # AssertionError is kept for callers that already catch it.
    if not (length_ok and len(silent_ranges) > 0):
        raise AssertionError(
            "Each sentence must be within 70 seconds, including silence"
        )

    audio_chunks = []
    prev_end = 0
    for idx, (_start, end) in enumerate(silent_ranges):
        is_first = idx == 0
        # Later chunks start 5 ms (80 samples) early to overlap the previous cut.
        chunk_start = prev_end if is_first else prev_end - 5
        if idx < len(silent_ranges) - 1:
            # Cut just past the next silence's start, snapped to 320 samples.
            chunk_length = silent_ranges[idx + 1][0] - prev_end
            # NOTE(review): for idx == 0, prev_end is 0. Any speech before the
            # first silence is therefore counted as part of this "silence".
            silence_length = end - prev_end
            adjusted_end = prev_end + _snap_ms_to_frames(chunk_length, is_first)
        else:
            # A trailing silence that runs to the end of the audio is not split.
            silence_length = (
                silent_ranges[-1][1] - prev_end
                if silent_ranges[-1][1] != len(audio_segment)
                else 0
            )
            adjusted_end = len(audio_segment)

        silence_length_split = max(0, (silence_length - 300))  # ms
        if silence_length_split <= 0:
            silence_chunk = None
            chunk = audio_segment[chunk_start:adjusted_end]
        else:
            silence_length_split = _snap_ms_to_frames(silence_length_split, is_first)
            silence_chunk = audio_segment[chunk_start : prev_end + silence_length_split]
            chunk = audio_segment[prev_end + silence_length_split - 5 : adjusted_end]

        if len(chunk) >= min_chunk_length_ms:
            if silence_chunk is not None:
                audio_chunks.append(silence_chunk)
            audio_chunks.append(chunk)
        elif audio_chunks:
            # Too short to encode on its own: glue it onto the previous chunk.
            if silence_chunk is not None:
                audio_chunks[-1] += silence_chunk
            audio_chunks[-1] += chunk
        # NOTE(review): a too-short chunk with no predecessor is dropped together
        # with its silence chunk, so that audio is lost. Confirm this is intended.
        prev_end = adjusted_end
    return audio_chunks
def process_audio_chunks(
    audio_processor, audio_encoder, audio_chunks: list[AudioSegment], device
):
    """Encode every chunk independently with ``process_audio_chunk``.

    Returns the per-chunk results in the same order as ``audio_chunks``.
    """
    return [
        process_audio_chunk(audio_processor, audio_encoder, chunk, device)
        for chunk in audio_chunks
    ]
def process_audio_chunk(audio_processor, audio_encoder, audio_chunk, device):
    """Encode one audio chunk and return the first batch item's hidden states.

    Integer PCM samples are converted to float32 and divided by the max of the
    signed integer type matching the sample width: int8 for 1 byte, int16 for
    2 bytes, int32 otherwise. The result goes through ``audio_processor``
    (declared 16 kHz, PyTorch tensors) and is moved to ``device``. It is then
    encoded under ``torch.no_grad()`` and ``last_hidden_state[0]`` is returned.
    """
    if audio_chunk.sample_width == 1:
        int_type = np.int8
    elif audio_chunk.sample_width == 2:
        int_type = np.int16
    else:
        int_type = np.int32
    samples = np.array(audio_chunk.get_array_of_samples(), dtype=np.float32)
    samples /= np.iinfo(int_type).max

    processed = audio_processor(samples, sampling_rate=16000, return_tensors="pt")
    input_values = processed.to(device)["input_values"]
    with torch.no_grad():
        encoded = audio_encoder(input_values=input_values)
    return encoded.last_hidden_state[0]
def audio_encode(model, audio_segment, device):
    """Split ``audio_segment`` at silences, encode each chunk, and stack the features.

    Uses ``model.audio_processor`` and ``model.audio_encoder``. Per-chunk
    features are concatenated along dim 0 and returned as a CPU numpy array.
    """
    chunks = load_and_split_audio_by_silence(audio_segment)
    per_chunk = process_audio_chunks(
        model.audio_processor, model.audio_encoder, chunks, device
    )
    stacked = torch.cat(per_chunk, dim=0)
    return stacked.detach().cpu().numpy()
def dictzip(*iterators):
    """Yield one merged dict per step, taking one dict from each source in lockstep.

    Each step takes the next dict from every source in argument order and
    merges their items; later sources win on duplicate keys. Iteration stops
    when the shortest source is exhausted, like ``zip``. Any iterable is
    accepted, not only iterators. With no sources, nothing is yielded.
    """
    for dicts in zip(*iterators):
        yield dict(chain.from_iterable(d.items() for d in dicts))
async def adictzip(*aiterators): | |
try: | |
while True: | |
yield dict( | |
chain(*[(await anext(aiterator)).items() for aiterator in aiterators]) | |
) | |
except StopAsyncIteration as e: | |
pass | |
def to_img(t):
    """Convert a 4-D image tensor in [-1, 1] to a uint8 numpy array.

    Axes are permuted (0, 2, 3, 1), presumably NCHW -> NHWC. Values are mapped
    from [-1, 1] to [0, 255], clipped, and truncated to uint8. The last axis is
    then reordered as channels [2, 1, 0], i.e. RGB <-> BGR for 3-channel input.
    """
    nhwc = t.permute(0, 2, 3, 1)
    scaled = (nhwc / 2.0 + 0.5) * 255.0
    pixels = torch.clip(scaled, 0.0, 255.0).to(torch.uint8)
    return pixels.cpu().numpy()[..., [2, 1, 0]]
def inference_model(model, v, device, verbose=False):
    """Run ``model.model`` locally on one batch and return one result dict per item.

    The batch dict ``v`` must provide ``"mel"``, ``"ips"``, ``"mask"``,
    ``"img_gt_with_alpha"`` and ``"filename"``.
    - ``mel`` is moved to ``device``.
    - ``ips`` is moved to ``device`` and permuted (0, 3, 1, 2), presumably
      NHWC -> NCHW.
    The model output goes through ``to_img``. The per-item mask, ips and alpha
    entries are taken from the original (un-moved) tensors via ``.numpy()``, so
    they are assumed to live on the CPU. Everything runs under
    ``torch.no_grad()``. ``verbose`` is unused.
    """
    with torch.no_grad():
        host_ips = v["ips"]
        masks = v["mask"]
        host_alpha = v["img_gt_with_alpha"]
        mel_on_device = v["mel"].to(device)
        ips_nchw = host_ips.to(device).permute(0, 3, 1, 2)
        faces = to_img(model.model(ips_nchw, mel_on_device))
        results = []
        for j, face in enumerate(faces):
            results.append(
                {
                    "pred": face,
                    "mask": masks[j].numpy(),
                    "ips": host_ips[j].numpy(),
                    "img_gt_with_alpha": host_alpha[j].numpy(),
                    "filename": v["filename"][j],
                }
            )
        return results
def inference_model_remote(model, v, device, verbose=False):
    """Run one batch through a remote model and post-process the prediction.

    Best effort: if the remote call or ``postprocess_result`` raises, the error
    is logged with its traceback and a list of ``None`` placeholders is
    returned instead. There is one placeholder per entry of ``v["filename"]``.

    Args:
        model: Object whose ``model`` attribute is called as
            ``model.model(ips=..., mel=...)``.
        v: Batch dict. This function reads ``"ips"``, ``"mel"`` and
            ``"filename"``; ``postprocess_result`` reads further keys.
        device: Unused; kept for signature parity with ``inference_model``.
        verbose: Unused.

    Returns:
        A list of result dicts, or a list of ``None`` on failure.
    """
    ips, mel = v["ips"], v["mel"]
    try:
        pred = model.model(ips=ips, mel=mel)
        return postprocess_result(pred, v)
    except Exception:
        # Deliberately best-effort, but never silent: record what went wrong.
        n_items = len(v["filename"])
        logging.getLogger(__name__).exception(
            "Remote inference failed for %d item(s); returning None placeholders",
            n_items,
        )
        return [None] * n_items
def postprocess_result(pred, v):
    """Turn a batch of model outputs into one result dict per item.

    ``pred`` is moved to the CPU and converted to numpy. It is then transposed
    with axes (0, 2, 3, 1), presumably NCHW -> NHWC, and its last axis is
    reordered [2, 1, 0] (RGB <-> BGR for 3-channel output). No rescaling or
    dtype conversion happens here.
    NOTE(review): presumably the remote model already returns pixel values.
    Confirm.
    The ``"mask"`` and ``"img_gt_with_alpha"`` entries of ``v`` are converted
    per item with ``.numpy()``.
    """
    frames = pred.cpu().numpy().transpose(0, 2, 3, 1)[:, :, :, [2, 1, 0]]
    results = []
    for j, frame in enumerate(frames):
        results.append(
            {
                "pred": frame,
                "mask": v["mask"][j].numpy(),
                "img_gt_with_alpha": v["img_gt_with_alpha"][j].numpy(),
                "filename": v["filename"][j],
            }
        )
    return results
async def ainference_model_remote(pool, model, v, device, verbose=False):
    """Async variant: await the remote model, then post-process in ``pool``.

    ``postprocess_result`` runs via ``loop.run_in_executor(pool, ...)`` so it
    does not block the event loop. Best effort: if the remote call or the
    post-processing raises an ``Exception``, the error is logged with its
    traceback. A list of ``None`` placeholders (one per ``v["filename"]``
    entry) is then returned. ``asyncio.CancelledError`` is a ``BaseException``
    and still propagates.

    Args:
        pool: Executor passed to ``run_in_executor`` (``None`` means the
            loop's default executor).
        model: Object whose ``model`` attribute is awaited as
            ``model.model(ips=..., mel=...)``.
        v: Batch dict; reads ``"ips"``, ``"mel"`` and ``"filename"`` here.
        device: Unused.
        verbose: Unused.
    """
    ips, mel = v["ips"], v["mel"]
    try:
        pred = await model.model(ips=ips, mel=mel)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(pool, postprocess_result, pred, v)
    except Exception:
        # Deliberately best-effort, but never silent: record what went wrong.
        n_items = len(v["filename"])
        logging.getLogger(__name__).exception(
            "Remote inference failed for %d item(s); returning None placeholders",
            n_items,
        )
        return [None] * n_items
def get_head_box(df, move=False, head_box_idx=0, template_ratio=1.0):
    """Return the head crop box ``(x1, y1, x2, y2)`` from ``df["cropped_box"]``.

    Args:
        df: DataFrame-like object whose ``"cropped_box"`` column holds
            4-element boxes.
        move: If True, return the box at index label ``head_box_idx`` as stored
            (no scaling). Otherwise, scale the first row's box by
            ``template_ratio`` and round it to integers.
        head_box_idx: Index label used when ``move`` is True.
        template_ratio: Scale factor for the non-``move`` path. This is the
            ratio by which the template was downscaled from the original 4K
            template.

    Returns:
        The four box coordinates as a tuple.
    """
    if move:
        # NOTE(review): no template_ratio scaling is applied on this path.
        # Confirm the per-frame boxes are already in template coordinates.
        x1, y1, x2, y2 = np.array(df["cropped_box"][head_box_idx])
    else:
        # Shrink cropped_box by the ratio the template was downscaled from the
        # original 4K template. Cast to a wide signed integer, not uint8:
        # pixel coordinates routinely exceed 255 and uint8 would silently wrap.
        x1, y1, x2, y2 = np.round(
            np.array(df["cropped_box"].values[0]) * template_ratio
        ).astype(np.int64)
    return x1, y1, x2, y2