from typing import Dict, List, Optional, Union

import numpy as np
import transformers
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import AudioInput
from transformers.utils import TensorType


def build_audio_tokens(
    text: List[str],
    audio_features: Union[Dict, List[List[np.ndarray]]],
    config: Wav2Vec2Config,
    audio_token: str = "<|audio|>",
) -> List[str]:
    """Expand each `audio_token` placeholder into clip-specific tokens.

    The j-th placeholder in a prompt becomes `<|audio_{j+1}|>` repeated once
    per embedding the audio encoder will emit for that clip, keeping the text
    and audio streams aligned. When `audio_features` is the packed dict, the
    token counts are computed from the padded clip length.
    """
    if not isinstance(audio_features, list):
        audio_features = audio_features["audio_features"]
    for i in range(len(audio_features)):
        for j in range(len(audio_features[i])):
            num_embeddings = get_num_embeddings(audio_features[i][j].shape[0], config)
            tgt_token = f"<|audio_{j+1}|>" * num_embeddings
            text[i] = text[i].replace(audio_token, tgt_token, 1)
    return text


def calculate_output_length(
    length_in: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    """Output length of a 1-D convolution (standard Conv1d formula)."""
    return (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def get_num_embeddings(wav_length: int, config: Wav2Vec2Config) -> int:
    """Number of tokens the encoder emits for a clip of `wav_length` samples:
    the conv feature-extractor stack, then the adapter convolution, plus the
    two delimiter tokens."""
    curr_len = wav_length
    for i in range(config.num_feat_extract_layers):
        curr_len = calculate_output_length(
            curr_len, config.conv_kernel[i], stride=config.conv_stride[i]
        )
    curr_len = calculate_output_length(
        curr_len, config.adapter_kernel_size, stride=config.adapter_stride
    )
    return curr_len + 2  # 2 = <|begin_of_audio|>, <|end_of_audio|>


class MllamaAudioFeatureExtractor(Wav2Vec2FeatureExtractor):
    def __call__(
        self,
        batch_audio_clips: List[List[AudioInput]],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        # Normalize every clip with the parent Wav2Vec2 extractor; each result
        # is a 1-D float array of waveform samples. The explicit two-argument
        # super() is required because zero-argument super() does not work
        # inside a comprehension.
        audio_features = [
            [
                super(MllamaAudioFeatureExtractor, self).__call__(
                    audio_j, sampling_rate=16000, return_attention_mask=False
                )["input_values"][0]
                for audio_j in audio_i
            ]
            for audio_i in batch_audio_clips
        ]
        packed_audio_features = self.pack_audio_clips(audio_features)
        return BatchFeature(
            data={"audio_features": packed_audio_features},
            tensor_type=return_tensors,
        )

    def pack_audio_clips(self, batch_audio_clips: List[List[np.ndarray]]) -> np.ndarray:
        assert batch_audio_clips[0][0].ndim == 1  # each clip is a 1-D waveform
        # Zero-pad into shape (batch_size, max_num_clips, max_samples).
        batch_size = len(batch_audio_clips)
        max_num_clips = max(len(clips) for clips in batch_audio_clips)
        max_samples = max(clip.shape[0] for clips in batch_audio_clips for clip in clips)
        stacked_audio_clips = np.zeros(
            (batch_size, max_num_clips, max_samples), dtype=np.float32
        )
        for i, clips in enumerate(batch_audio_clips):
            for j, clip in enumerate(clips):
                stacked_audio_clips[i, j, : clip.shape[0]] = clip
        return stacked_audio_clips


# Make the class discoverable through the Auto API and importable by name.
AutoFeatureExtractor.register("MllamaAudioFeatureExtractor", MllamaAudioFeatureExtractor)
transformers.MllamaAudioFeatureExtractor = MllamaAudioFeatureExtractor
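

# --- Usage sketch ----------------------------------------------------------
# A minimal, illustrative run-through of the module above, assuming 16 kHz
# mono clips as 1-D float arrays and a stock `Wav2Vec2Config` (wav2vec2-base
# conv stack plus the default adapter). The prompt string and clip lengths
# are made up for this example, not taken from the original code.
if __name__ == "__main__":
    config = Wav2Vec2Config(add_adapter=True)  # assumed encoder config
    extractor = MllamaAudioFeatureExtractor()

    # One sample containing two clips: 1.0 s and 0.5 s of noise.
    clips = [
        [
            np.random.randn(16000).astype(np.float32),
            np.random.randn(8000).astype(np.float32),
        ]
    ]
    features = extractor(clips)
    print(features["audio_features"].shape)  # (1, 2, 16000), zero-padded

    # Expand each <|audio|> placeholder. Because the features are packed,
    # both clips use the padded length, so both expand to the same count.
    prompts = ["First clip: <|audio|> Second clip: <|audio|>"]
    prompts = build_audio_tokens(prompts, features, config)
    print(prompts[0].count("<|audio_1|>"), prompts[0].count("<|audio_2|>"))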