import asyncio
from itertools import chain
import numpy as np
import torch
from pydub import AudioSegment, silence


def check_split_lengths(silent_ranges, len_audio):
    # Every chunk (from one silence start to the next silence start, or to the
    # end of the audio) must be at most 70 seconds long.
    for idx, (start, end) in enumerate(silent_ranges):
        if idx < len(silent_ranges) - 1:
            if silent_ranges[idx + 1][0] - start > 70000:
                return False
        else:
            if len_audio - start > 70000:
                return False
    return True


def load_and_split_audio_by_silence(
    audio_segment,
    silence_thresh: int = -75,
    min_silence_len: int = 500,
    min_chunk_length_ms: int = 40,
    seek_step: int = 100,
    verbose: bool = False,
):
    """Split a pydub AudioSegment into alternating silence/speech chunks.

    The audio is converted to mono 16 kHz, then pydub's silence detection is
    retried with a progressively higher threshold (5 dB steps toward -50 dBFS)
    and a shorter minimum silence length (100 ms steps) until every chunk is
    at most 70 seconds long. Chunk boundaries are rounded so that chunk
    lengths align to multiples of 320 samples (20 ms at 16 kHz), with 80 extra
    samples on the first chunk; chunks shorter than min_chunk_length_ms are
    merged into the previous chunk.
    """
    audio_segment = audio_segment.set_channels(1)
    audio_segment = audio_segment.set_frame_rate(16000)
    for st in range(silence_thresh, -50, 5):
        for msl in range(min_silence_len, 0, -100):
            silent_ranges = silence.detect_silence(
                audio_segment, msl, st, seek_step=seek_step
            )
            length_ok = check_split_lengths(silent_ranges, len(audio_segment))
            if length_ok:
                break
        if len(silent_ranges) > 0 and length_ok:
            break
    if (
        len(silent_ranges) == 0
        and len(audio_segment) < 70000
        and len(audio_segment) >= 40
    ):
        return [audio_segment]
    assert (
        length_ok and len(silent_ranges) > 0
    ), "Each sentence must be within 70 seconds, including silence"
    audio_chunks = []
    prev_end = 0
    for idx, (start, end) in enumerate(silent_ranges):
        if idx < len(silent_ranges) - 1:
            chunk_length = silent_ranges[idx + 1][0] - prev_end
            silence_length = end - prev_end
            chunk_length_samples = (
                chunk_length * 16
            )  # Convert ms to samples (16000 samples/sec)
            if idx == 0:
                target_length_samples = (chunk_length_samples // 320 + 1) * 320 + 80
            else:
                target_length_samples = (chunk_length_samples // 320 + 1) * 320
            target_length = target_length_samples // 16  # Convert samples back to ms
            adjusted_end = prev_end + target_length
        else:
            silence_length = (
                silent_ranges[-1][1] - prev_end
                if silent_ranges[-1][1] != len(audio_segment)
                else 0
            )
            adjusted_end = len(audio_segment)
        silence_length_split = max(0, (silence_length - 300))  # ms
        if silence_length_split <= 0:
            silence_chunk = None
            chunk = audio_segment[prev_end if idx == 0 else prev_end - 5 : adjusted_end]
        else:
            silence_length_samples = (
                silence_length_split * 16
            )  # Convert ms to samples (16000 samples/sec)
            if idx == 0:
                target_length_samples = (silence_length_samples // 320 + 1) * 320 + 80
            else:
                target_length_samples = (silence_length_samples // 320 + 1) * 320
            silence_length_split = (
                target_length_samples // 16
            )  # Convert samples back to ms
            silence_chunk = audio_segment[
                prev_end if idx == 0 else prev_end - 5 : prev_end + silence_length_split
            ]
            chunk = audio_segment[prev_end + silence_length_split - 5 : adjusted_end]
        if len(chunk) >= min_chunk_length_ms:
            if silence_chunk is not None:
                audio_chunks.append(silence_chunk)
            audio_chunks.append(chunk)
        else:
            if audio_chunks:
                if silence_chunk is not None:
                    audio_chunks[-1] += silence_chunk
                audio_chunks[-1] += chunk
        prev_end = adjusted_end
    return audio_chunks
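
# Usage sketch (illustrative only; the file path below is hypothetical):
#     seg = AudioSegment.from_file("speech.wav")
#     chunks = load_and_split_audio_by_silence(seg)
#     print([len(c) for c in chunks])  # chunk lengths in milliseconds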


def process_audio_chunks(
    audio_processor, audio_encoder, audio_chunks: list[AudioSegment], device
):
    features_list = []
    for audio_chunk in audio_chunks:
        features = process_audio_chunk(
            audio_processor, audio_encoder, audio_chunk, device
        )
        features_list.append(features)
    return features_list


def process_audio_chunk(audio_processor, audio_encoder, audio_chunk, device):
    # Normalize integer PCM samples to [-1.0, 1.0] according to the sample width.
    audio_data = np.array(audio_chunk.get_array_of_samples(), dtype=np.float32)
    audio_data /= np.iinfo(
        np.int8
        if audio_chunk.sample_width == 1
        else np.int16
        if audio_chunk.sample_width == 2
        else np.int32
    ).max
    input_values = audio_processor(
        audio_data, sampling_rate=16000, return_tensors="pt"
    ).to(device)["input_values"]
    with torch.no_grad():
        logits = audio_encoder(input_values=input_values)
    return logits.last_hidden_state[0]


def audio_encode(model, audio_segment, device):
    audio_chunks = load_and_split_audio_by_silence(audio_segment)
    features_list = process_audio_chunks(
        model.audio_processor, model.audio_encoder, audio_chunks, device
    )
    concatenated_features = torch.cat(features_list, dim=0)
    return concatenated_features.detach().cpu().numpy()
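
# Usage sketch (illustrative; `model` must expose the `audio_processor` and
# `audio_encoder` attributes used above, and the device string is an example):
#     seg = AudioSegment.from_file("speech.wav")
#     audio_features = audio_encode(model, seg, device="cuda")
#     print(audio_features.shape)  # (num_frames, hidden_dim)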


def dictzip(*iterators):
    # Merge one dict from each iterator per step; stop when any iterator is exhausted.
    try:
        while True:
            yield dict(chain(*[next(iterator).items() for iterator in iterators]))
    except StopIteration:
        pass


async def adictzip(*aiterators):
    # Async variant of dictzip for async iterators.
    try:
        while True:
            yield dict(
                chain(*[(await anext(aiterator)).items() for aiterator in aiterators])
            )
    except StopAsyncIteration:
        pass
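
# Example: dictzip(iter([{"a": 1}, {"a": 2}]), iter([{"b": 3}]))
# yields {"a": 1, "b": 3} and then stops, because the shorter iterator is exhausted.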


def to_img(t):
    # Convert an NCHW float tensor in [-1, 1] to uint8 HWC images,
    # reversing the channel order (RGB <-> BGR).
    t = t.permute(0, 2, 3, 1)
    img = ((t / 2.0) + 0.5) * 255.0
    img = torch.clip(img, 0.0, 255.0).type(torch.uint8)
    img = img.cpu().numpy()
    img = img[:, :, :, [2, 1, 0]]
    return img


def inference_model(model, v, device, verbose=False):
    with torch.no_grad():
        mel, ips, mask, alpha = (
            v["mel"],
            v["ips"],
            v["mask"],
            v["img_gt_with_alpha"],
        )
        cpu_ips = ips
        cpu_alpha = alpha
        audio = mel.to(device)
        ips = ips.to(device).permute(0, 3, 1, 2)
        pred = model.model(ips, audio)
        gen_face = to_img(pred)
        return [
            {
                "pred": o,
                "mask": mask[j].numpy(),
                "ips": cpu_ips[j].numpy(),
                "img_gt_with_alpha": cpu_alpha[j].numpy(),
                "filename": v["filename"][j],
            }
            for j, o in enumerate(gen_face)
        ]


def inference_model_remote(model, v, device, verbose=False):
    ips, mel = v["ips"], v["mel"]
    try:
        pred = model.model(
            ips=ips,
            mel=mel,
        )
        return postprocess_result(pred, v)
    except Exception:
        return [None] * len(v["filename"])


def postprocess_result(pred, v):
    pred = pred.cpu().numpy()
    pred = pred.transpose(0, 2, 3, 1)
    pred = pred[:, :, :, [2, 1, 0]]
    return [
        {
            "pred": o,
            "mask": v["mask"][j].numpy(),
            "img_gt_with_alpha": v["img_gt_with_alpha"][j].numpy(),
            "filename": v["filename"][j],
        }
        for j, o in enumerate(pred)
    ]


async def ainference_model_remote(pool, model, v, device, verbose=False):
    ips, mel = v["ips"], v["mel"]
    try:
        pred = await model.model(
            ips=ips,
            mel=mel,
        )
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(pool, postprocess_result, pred, v)
    except Exception:
        return [None] * len(v["filename"])


def get_head_box(df, move=False, head_box_idx=0, template_ratio=1.0):
    # sz = df['cropped_size'].values[0]
    # Scale the cropped_box down by the same ratio the original 4K template was shrunk by.
    if move:
        x1, y1, x2, y2 = np.array(df["cropped_box"][head_box_idx])
    else:
        x1, y1, x2, y2 = np.round(
            np.array(df["cropped_box"].values[0]) * template_ratio
        ).astype(np.uint8)
    return x1, y1, x2, y2
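
# Input sketch (hypothetical values): `df` is expected to be a pandas-style frame
# with a "cropped_box" column holding [x1, y1, x2, y2] pixel coordinates, e.g.
#     df = pd.DataFrame({"cropped_box": [[12, 30, 96, 120]]})
#     x1, y1, x2, y2 = get_head_box(df)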