# Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py

import os

import numpy as np
import torch

from .whisper import load_model


class Audio2Feature:
    """Extract Whisper encoder embeddings from audio and slice them per video frame.

    Audio feature indices live on a 50-steps-per-second timeline (see the
    ``50 / fps`` conversions below); each video frame index is mapped onto that
    timeline to select a window of audio features around it.
    """

    def __init__(
        self,
        model_path="checkpoints/whisper/tiny.pt",
        device=None,
        audio_embeds_cache_dir=None,
        num_frames=16,
    ):
        """
        :param model_path: Whisper checkpoint path, forwarded to ``load_model``.
        :param device: device forwarded to ``load_model``.
        :param audio_embeds_cache_dir: directory holding cached ``<audio basename>.pt``
            embeddings; ``None`` or ``""`` disables caching.
        :param num_frames: number of consecutive frames stacked by
            ``crop_overlap_audio_window``.
        """
        self.model = load_model(model_path, device)
        self.audio_embeds_cache_dir = audio_embeds_cache_dir
        self.num_frames = num_frames
        # Width of each embedding row, taken from the model's audio state size.
        self.embedding_dim = self.model.dims.n_audio_state

    def get_sliced_feature(self, feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
        """Return the window of audio features around video frame ``vid_idx``.

        The frame index is mapped to the centre audio index ``int(vid_idx * 50 / fps)``.
        The window covers ``2 * audio_feat_length[0]`` audio indices before the centre
        and ``2 * (audio_feat_length[1] + 1)`` indices from the centre onward (the
        factor 2 is fixed, i.e. two audio steps per frame). Indices outside
        ``feature_array`` are clamped, so boundary frames repeat the edge feature.

        :param feature_array: indexable sequence of feature tensors, e.g. the tensor
            returned by ``audio2feat``.
        :param vid_idx: video frame index.
        :param audio_feat_length: (frames before, frames after) the centre frame.
        :param fps: video frame rate.
        :return: ``(selected_feature, selected_idx)``: the selected features
            concatenated along dim 0 and reshaped to ``(-1, embedding_dim)``, and the
            list of (clamped) audio indices that were used.
        """
        length = len(feature_array)
        center_idx = int(vid_idx * 50 / fps)
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + (audio_feat_length[1] + 1) * 2

        selected_feature = []
        selected_idx = []
        for idx in range(left_idx, right_idx):
            # Clamp into [0, length - 1].
            idx = min(max(0, idx), length - 1)
            selected_feature.append(feature_array[idx])
            selected_idx.append(idx)

        selected_feature = torch.cat(selected_feature, dim=0)
        # Upstream notes this is 50*384 (whisper tiny, default window).
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)
        return selected_feature, selected_idx

    def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
        """Return a sparse window: a pair of audio steps per frame around ``vid_idx``.

        For each frame offset ``dt`` in ``[-audio_feat_length[0], audio_feat_length[1]]``
        the audio index ``int((vid_idx + dt) * 50 / fps)`` is computed and the two
        steps ``[idx - 1, idx]`` are taken. When that pair is not fully in range,
        the clamped single step is duplicated instead.

        NOTE(review): the result is built with ``np.concatenate`` and
        ``torch.from_numpy``, so ``feature_array`` presumably must be a NumPy array
        of 2-D per-step features; ``audio2feat`` returns a torch tensor — confirm
        callers convert before using this method.

        :param feature_array: per-step features (see note above).
        :param vid_idx: video frame index.
        :param audio_feat_length: (frames before, frames after) the centre frame.
        :param fps: video frame rate.
        :return: ``(selected_feature, selected_idx)``: a tensor reshaped to
            ``(-1, embedding_dim)`` and the list of audio indices used.
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []

        for dt in range(-audio_feat_length[0], audio_feat_length[1] + 1):
            left_idx = int((vid_idx + dt) * 50 / fps)
            if left_idx < 1 or left_idx > length - 1:
                # No valid [left_idx - 1, left_idx] pair: duplicate the clamped step.
                left_idx = min(max(0, left_idx), length - 1)
                x = feature_array[left_idx]
                x = x[np.newaxis, :, :]
                x = np.repeat(x, 2, axis=0)
                selected_feature.append(x)
                selected_idx.append(left_idx)
                selected_idx.append(left_idx)
            else:
                x = feature_array[left_idx - 1 : left_idx + 1]
                selected_feature.append(x)
                selected_idx.append(left_idx - 1)
                selected_idx.append(left_idx)

        selected_feature = np.concatenate(selected_feature, axis=0)
        # Upstream notes this is 50*384 (whisper tiny, default window).
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)
        selected_feature = torch.from_numpy(selected_feature)
        return selected_feature, selected_idx

    def feature2chunks(self, feature_array, fps, audio_feat_length=(2, 2)):
        """Slice ``feature_array`` into one ``get_sliced_feature`` window per video frame.

        :param feature_array: features as accepted by ``get_sliced_feature``.
        :param fps: video frame rate.
        :param audio_feat_length: (frames before, frames after) for each window.
        :return: list of per-frame feature tensors.
        """
        whisper_chunks = []
        whisper_idx_multiplier = 50.0 / fps
        i = 0
        print(f"video in {fps} FPS, audio idx in 50FPS")
        while True:
            start_idx = int(i * whisper_idx_multiplier)
            selected_feature, _ = self.get_sliced_feature(
                feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps
            )
            whisper_chunks.append(selected_feature)
            i += 1
            # The check runs after appending, so the first frame whose start index
            # exceeds len(feature_array) is still included (upstream behavior).
            if start_idx > len(feature_array):
                break
        return whisper_chunks

    def _audio2feat(self, audio_path: str):
        """Transcribe ``audio_path`` with Whisper and concatenate per-segment encoder embeddings.

        :param audio_path: path to the audio file.
        :return: torch tensor built from the concatenated NumPy embeddings (dim 0 = steps).
        :raises ValueError: if the transcription returned no segments.
        """
        # The (forked) transcribe is expected to attach "encoder_embeddings",
        # "start" and "end" to every segment.
        result = self.model.transcribe(audio_path)
        embed_list = []
        for emb in result["segments"]:
            encoder_embeddings = emb["encoder_embeddings"]
            # Swap axes 1 and 2, then drop the leading axis (must have size 1).
            encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
            encoder_embeddings = encoder_embeddings.squeeze(0)
            start_idx = int(emb["start"])
            end_idx = int(emb["end"])
            # Keep half of the segment's start/end span; presumably start/end are
            # counted at twice the encoder step rate — TODO confirm in the whisper fork.
            emb_end_idx = int((end_idx - start_idx) / 2)
            embed_list.append(encoder_embeddings[:emb_end_idx])
        if not embed_list:
            # np.concatenate would otherwise fail with an opaque message.
            raise ValueError(f"Whisper returned no segments for audio: {audio_path}")
        concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
        return concatenated_array

    def audio2feat(self, audio_path):
        """Return Whisper features for ``audio_path``, using an on-disk cache when configured.

        The cache file is ``<audio_embeds_cache_dir>/<basename(audio_path)>.pt``. A
        cache file that fails to load is reported, deleted and regenerated.

        :param audio_path: path to the audio file.
        :return: the tensor produced by ``_audio2feat`` (or loaded from cache).
        """
        if not self.audio_embeds_cache_dir:
            return self._audio2feat(audio_path)

        audio_embeds_cache_path = os.path.join(self.audio_embeds_cache_dir, os.path.basename(audio_path) + ".pt")

        if os.path.isfile(audio_embeds_cache_path):
            try:
                return torch.load(audio_embeds_cache_path)
            except Exception as e:
                # Deliberately broad: any unreadable cache entry is rebuilt below.
                print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
                os.remove(audio_embeds_cache_path)

        audio_feat = self._audio2feat(audio_path)
        # Create the cache directory on first use instead of failing in torch.save.
        os.makedirs(self.audio_embeds_cache_dir, exist_ok=True)
        torch.save(audio_feat, audio_embeds_cache_path)
        return audio_feat

    def crop_overlap_audio_window(self, audio_feat, start_index):
        """Stack the feature windows of ``num_frames`` consecutive frames.

        Uses ``get_sliced_feature`` with a (2, 2) window at 25 fps for frames
        ``start_index`` .. ``start_index + num_frames - 1``.

        :param audio_feat: features as accepted by ``get_sliced_feature``.
        :param start_index: first video frame index.
        :return: tensor stacked along a new leading dim of size ``num_frames``.
        """
        selected_feature_list = [
            self.get_sliced_feature(feature_array=audio_feat, vid_idx=i, audio_feat_length=[2, 2], fps=25)[0]
            for i in range(start_index, start_index + self.num_frames)
        ]
        mel_overlap = torch.stack(selected_feature_list)
        return mel_overlap


if __name__ == "__main__":
    audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
    audio_path = "assets/demo1_audio.wav"
    array = audio_encoder.audio2feat(audio_path)
    print(array.shape)
    fps = 25
    whisper_idx_multiplier = 50.0 / fps

    i = 0
    print(f"video in {fps} FPS, audio idx in 50FPS")
    while True:
        start_idx = int(i * whisper_idx_multiplier)
        selected_feature, selected_idx = audio_encoder.get_sliced_feature(
            feature_array=array, vid_idx=i, audio_feat_length=[2, 2], fps=fps
        )
        print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
        i += 1
        if start_idx > len(array):
            break