from typing import List, Optional, Union

import numpy as np

import transformers
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.tokenization_utils_base import AudioInput
from transformers.utils import TensorType


def build_audio_tokens(
    text: List[str],
    audio_features: List[List[np.ndarray]],
    audio_token: str = "<|audio|>",
) -> List[str]:
    """Expand each `audio_token` placeholder in `text` into a run of
    clip-specific tokens, one token per embedding the audio stack produces."""
    bs = len(audio_features)
    for i in range(bs):
        for j in range(len(audio_features[i])):
            # The j-th clip of sample i gets the token <|audio_{j+1}|>, repeated
            # once per output embedding so that the text length matches the
            # number of audio embeddings spliced in.
            tgt_token = f"<|audio_{j+1}|>" * get_num_embeddings(audio_features[i][j].shape[0])
            text[i] = text[i].replace(audio_token, tgt_token, 1)
    return text
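
# Illustration with made-up numbers: if a sample's only clip yielded
# get_num_embeddings(...) == 3, then "Describe: <|audio|>" would become
# "Describe: <|audio_1|><|audio_1|><|audio_1|>".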


def calculate_output_length(length_in: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of a 1D convolution, per the torch.nn.Conv1d shape formula."""
    return (length_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def get_num_embeddings(wav_length: int) -> int:
    """Number of audio embeddings produced for a clip of `wav_length` samples."""
    # Conv layout of the default Wav2Vec2 feature extractor (seven conv
    # layers), followed by a single adapter conv layer.
    num_feat_extract_layers = 7
    conv_kernel = [10, 3, 3, 3, 3, 2, 2]
    conv_stride = [5, 2, 2, 2, 2, 2, 2]
    adapter_kernel_size = 7
    adapter_stride = 4

    curr_len = wav_length
    for i in range(num_feat_extract_layers):
        curr_len = calculate_output_length(curr_len, conv_kernel[i], stride=conv_stride[i])
    curr_len = calculate_output_length(curr_len, adapter_kernel_size, stride=adapter_stride, padding=adapter_stride // 2)
    # Two extra embedding positions are reserved per clip.
    return curr_len + 2
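
# Worked example (pure arithmetic): for a 1-second clip at 16 kHz,
# wav_length = 16000. The seven conv layers shrink the sequence
# 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49, the adapter
# (kernel 7, stride 4, padding 2) shrinks 49 -> 12, and the two reserved
# positions give get_num_embeddings(16000) == 14.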


class MllamaAudioFeatureExtractor(Wav2Vec2FeatureExtractor):

    def __call__(
        self,
        batch_audio_clips: List[List[AudioInput]],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        # Run the parent Wav2Vec2 feature extractor on each clip individually,
        # preserving the per-sample list-of-clips structure. The explicit
        # two-argument super() is required: zero-argument super() does not
        # work inside comprehensions.
        audio_features = [
            [
                super(MllamaAudioFeatureExtractor, self).__call__(
                    audio_j, sampling_rate=16000, return_attention_mask=False
                )["input_values"][0]
                for audio_j in audio_i
            ]
            for audio_i in batch_audio_clips
        ]
        packed_audio_features = self.pack_audio_clips(audio_features)

        encoded_audio_inputs = BatchFeature(
            data={"audio_features": packed_audio_features},
            tensor_type=return_tensors,
        )

        return encoded_audio_inputs

    def pack_audio_clips(self, batch_audio_clips: List[List[np.ndarray]]) -> np.ndarray:
        """Zero-pad a ragged batch of mono clips into a dense
        (batch_size, max_num_clips, max_frames) float32 array."""
        assert batch_audio_clips[0][0].ndim == 1, "expected mono (1-D) audio clips"

        batch_size = len(batch_audio_clips)
        max_num_clips = max(len(clips) for clips in batch_audio_clips)
        max_frames = max(clip.shape[0] for clips in batch_audio_clips for clip in clips)

        stacked_audio_clips = np.zeros((batch_size, max_num_clips, max_frames), dtype=np.float32)
        for i, clips in enumerate(batch_audio_clips):
            for j, clip in enumerate(clips):
                # Copy each clip to the front of its slot; the tail stays zero.
                stacked_audio_clips[i, j, :clip.shape[0]] = clip

        return stacked_audio_clips
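
# Shape illustration (made-up lengths): for two samples, one with clips of
# 8 and 5 frames and one with a single 6-frame clip, pack_audio_clips
# returns a (2, 2, 8) array whose unused slots and clip tails are zeros.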


# Register under transformers so the class can be resolved by name, e.g. when
# AutoFeatureExtractor looks up the `feature_extractor_type` recorded in a
# saved preprocessor config.
AutoFeatureExtractor.register("MllamaAudioFeatureExtractor", MllamaAudioFeatureExtractor)
transformers.MllamaAudioFeatureExtractor = MllamaAudioFeatureExtractor
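

# Minimal end-to-end sketch (assumptions: the default Wav2Vec2FeatureExtractor
# settings are acceptable, and clips are mono 16 kHz float arrays; the prompt
# and clip below are made up for illustration).
if __name__ == "__main__":
    extractor = MllamaAudioFeatureExtractor()
    clips = [[np.zeros(16000, dtype=np.float32)]]  # one sample with one 1 s clip
    features = extractor(clips, return_tensors="np")
    print(features["audio_features"].shape)  # (1, 1, 16000)
    prompts = build_audio_tokens(["Describe this sound: <|audio|>"], clips)
    print(prompts[0])  # "<|audio_1|>" repeated get_num_embeddings(16000) == 14 times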