# LatentSync: latentsync/whisper/audio2feature.py (commit 24c345c, author: Francke)
# Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py
from .whisper import load_model
import numpy as np
import torch
import os
class Audio2Feature:
    """
    Extract Whisper encoder embeddings from audio and slice them into per-video-frame windows.

    Audio features are indexed at 50 frames per second (the hard-coded ``50`` below);
    video frames are indexed at ``fps``.
    """

    def __init__(
        self,
        model_path="checkpoints/whisper/tiny.pt",
        device=None,
        audio_embeds_cache_dir=None,
        num_frames=16,
    ):
        """
        :param model_path: Whisper checkpoint, forwarded to ``load_model``.
        :param device: forwarded to ``load_model`` unchanged.
        :param audio_embeds_cache_dir: directory for cached ``audio2feat`` results;
            None or "" disables caching.
        :param num_frames: number of consecutive video frames stacked by
            ``crop_overlap_audio_window``.
        """
        self.model = load_model(model_path, device)
        self.audio_embeds_cache_dir = audio_embeds_cache_dir
        self.num_frames = num_frames
        self.embedding_dim = self.model.dims.n_audio_state

    def get_sliced_feature(self, feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
        """
        Gather the window of audio features around one video frame.

        The centre audio index is ``int(vid_idx * 50 / fps)``. The window starts
        ``2 * audio_feat_length[0]`` audio frames before the centre. It ends
        ``2 * (audio_feat_length[1] + 1)`` frames after it, exclusive. Indices outside
        ``[0, len(feature_array) - 1]`` are clamped, so edge frames are repeated.

        :param feature_array: torch tensor of audio features indexed along dim 0
            (e.g. the output of ``audio2feat``).
        :param vid_idx: index of the video frame.
        :param audio_feat_length: (left, right) context sizes. Each unit spans 2 audio
            frames (one video frame at 25 FPS).
        :param fps: video frame rate.
        :return: ``(features, indices)``. ``features`` has shape ``(-1, embedding_dim)``.
            ``indices`` lists the clamped audio indices used, in order.
        """
        length = len(feature_array)
        center_idx = int(vid_idx * 50 / fps)
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + (audio_feat_length[1] + 1) * 2
        # Clamp out-of-range positions to the nearest valid frame (edge padding).
        selected_idx = [min(length - 1, max(0, idx)) for idx in range(left_idx, right_idx)]
        selected_feature = torch.cat([feature_array[idx] for idx in selected_idx], dim=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)  # original note: 50*384
        return selected_feature, selected_idx

    def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
        """
        Gather pairs of audio features at each video-frame offset around ``vid_idx``.

        For every offset ``dt`` in ``[-audio_feat_length[0], audio_feat_length[1]]``, let
        ``left_idx = int((vid_idx + dt) * 50 / fps)``:

        - If ``1 <= left_idx <= len - 1``, the frames ``[left_idx - 1, left_idx]`` are taken.
        - Otherwise ``left_idx`` is clamped into range and that single frame is used twice.

        :param feature_array: numpy array indexed along axis 0. It is processed with
            ``np.repeat``/``np.concatenate``, and each element is expected to be 2-D.
        :param vid_idx: index of the video frame.
        :param audio_feat_length: (left, right) context, in video-frame offsets.
        :param fps: video frame rate.
        :return: ``(features, indices)``. ``features`` is a torch tensor of shape
            ``(-1, embedding_dim)``, and ``indices`` lists the audio indices used.
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
        for dt in range(-audio_feat_length[0], audio_feat_length[1] + 1):
            left_idx = int((vid_idx + dt) * 50 / fps)
            if 1 <= left_idx <= length - 1:
                # Interior: the two consecutive frames ending at left_idx.
                selected_feature.append(feature_array[left_idx - 1 : left_idx + 1])
                selected_idx.extend([left_idx - 1, left_idx])
            else:
                # Edge: clamp into range and duplicate that single frame.
                left_idx = min(length - 1, max(0, left_idx))
                x = np.repeat(feature_array[left_idx][np.newaxis, :, :], 2, axis=0)
                selected_feature.append(x)
                selected_idx.extend([left_idx, left_idx])
        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)  # original note: 50*384
        return torch.from_numpy(selected_feature), selected_idx

    def feature2chunks(self, feature_array, fps, audio_feat_length=(2, 2)):
        """
        Slice ``feature_array`` into one ``get_sliced_feature`` window per video frame.

        The stop test runs after a chunk is appended. As a result, the last chunk's centre
        already lies past ``len(feature_array)``, and its indices are clamped to the final
        frame. Callers presumably truncate to the video length.

        :return: list of window tensors, one per video frame index starting at 0.
        """
        whisper_chunks = []
        whisper_idx_multiplier = 50.0 / fps
        i = 0
        print(f"video in {fps} FPS, audio idx in 50FPS")
        while True:
            start_idx = int(i * whisper_idx_multiplier)
            selected_feature, _ = self.get_sliced_feature(
                feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps
            )
            whisper_chunks.append(selected_feature)
            i += 1
            if start_idx > len(feature_array):
                break
        return whisper_chunks

    def _audio2feat(self, audio_path: str):
        """
        Run Whisper on ``audio_path`` and concatenate each segment's encoder embeddings.

        :raises ValueError: if the transcription produced no segments.
        """
        # transcribe() comes from the vendored whisper package and attaches
        # "encoder_embeddings" (a 4-D numpy array) to every segment.
        result = self.model.transcribe(audio_path)
        embed_list = []
        for segment in result["segments"]:
            # Swap axes 1 and 2, then drop the leading singleton axis.
            encoder_embeddings = segment["encoder_embeddings"].transpose(0, 2, 1, 3).squeeze(0)
            start_idx = int(segment["start"])
            end_idx = int(segment["end"])
            # NOTE(review): halving assumes start/end use units twice as fine as the
            # encoder frames (e.g. 100/s mel frames vs 50/s encoder frames). Confirm.
            emb_end_idx = int((end_idx - start_idx) / 2)
            embed_list.append(encoder_embeddings[:emb_end_idx])
        if not embed_list:
            raise ValueError(f"Whisper produced no segments for {audio_path}")
        return torch.from_numpy(np.concatenate(embed_list, axis=0))

    def audio2feat(self, audio_path):
        """
        Return Whisper features for ``audio_path``, using an on-disk cache when configured.

        If ``audio_embeds_cache_dir`` is unset (None or ""), features are always recomputed.
        Otherwise ``<cache_dir>/<basename(audio_path)>.pt`` is loaded when present. If that
        file cannot be loaded, it is reported, deleted and regenerated. The cache directory
        is created on demand, and new entries are written atomically.

        :param audio_path: path to the audio file.
        :return: the tensor produced by ``_audio2feat``, or its cached copy.
        """
        if not self.audio_embeds_cache_dir:
            return self._audio2feat(audio_path)

        # NOTE(review): the key is only the basename. Equally named files in different
        # directories therefore share, and overwrite, one cache entry.
        audio_embeds_cache_path = os.path.join(self.audio_embeds_cache_dir, os.path.basename(audio_path) + ".pt")
        if os.path.isfile(audio_embeds_cache_path):
            try:
                return torch.load(audio_embeds_cache_path)
            except Exception as e:
                # Corrupt or truncated entry: report it, drop it, and regenerate below.
                print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
                try:
                    os.remove(audio_embeds_cache_path)
                except FileNotFoundError:
                    pass  # already removed, e.g. by another process

        audio_feat = self._audio2feat(audio_path)
        os.makedirs(self.audio_embeds_cache_dir, exist_ok=True)
        self._save_audio_embeds(audio_feat, audio_embeds_cache_path)
        return audio_feat

    @staticmethod
    def _save_audio_embeds(audio_feat, cache_path):
        """
        Write ``audio_feat`` to ``cache_path`` via a temp file and ``os.replace``.

        Readers therefore never load a half-written cache file. The temp file is removed
        if saving fails.
        """
        # The PID suffix keeps concurrent writer processes from sharing a temp file.
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        try:
            torch.save(audio_feat, tmp_path)
            os.replace(tmp_path, cache_path)
        finally:
            # After a successful replace the temp file no longer exists.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def crop_overlap_audio_window(self, audio_feat, start_index):
        """
        Stack the ``get_sliced_feature`` windows for ``self.num_frames`` consecutive video
        frames, starting at ``start_index``.

        Uses fps=25 and a (2, 2) context.

        :return: tensor with a leading dimension of ``self.num_frames``.
        """
        selected_feature_list = [
            self.get_sliced_feature(feature_array=audio_feat, vid_idx=i, audio_feat_length=(2, 2), fps=25)[0]
            for i in range(start_index, start_index + self.num_frames)
        ]
        return torch.stack(selected_feature_list)
if __name__ == "__main__":
    # Smoke test: encode a demo clip, then show which audio indices (50 FPS) feed each
    # video frame (25 FPS). Like feature2chunks, it keeps going until the start index
    # has passed the end of the features.
    audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
    audio_path = "assets/demo1_audio.wav"
    whisper_feature = audio_encoder.audio2feat(audio_path)
    print(whisper_feature.shape)

    fps = 25
    audio_per_video = 50.0 / fps
    video_idx = 0
    print(f"video in {fps} FPS, audio idx in 50FPS")
    finished = False
    while not finished:
        audio_start = int(video_idx * audio_per_video)
        window, window_idx = audio_encoder.get_sliced_feature(
            feature_array=whisper_feature, vid_idx=video_idx, audio_feat_length=[2, 2], fps=fps
        )
        print(f"video idx {video_idx},\t audio idx {window_idx},\t shape {window.shape}")
        video_idx += 1
        finished = audio_start > len(whisper_feature)